Setup and Configuration

We load the dataset and display the first five rows, with all columns shown.

In [1]:
import numpy as np
from IPython.display import display, HTML
import matplotlib.pyplot as plt
from sklearn.metrics import confusion_matrix, classification_report
import pandas as pd
pd.options.display.max_columns = None

#df = pd.read_csv('S:/MHS_Pathfinder/bipolar_prediction/GOLD_&_AURUM_15_04_20.csv', low_memory=False)
df = pd.read_csv('clean_dataset.csv', low_memory=False)
display(df.head())
patid pracid diagnosis_date sex yob first_reg_date transfer_out_date death_date cohort_start cohort_end end_reason exposure incident_script dob exposure_end exposure_start suitable responder2 response2_1 adhd_date adhd alcohol_date alcohol asthma_date asthma cannabis_date cannabis conduct_date conduct dermatitis_date dermatitis migraine_date migraine other_substance_misuse_date other_substance_misuse psychosis_date psychosis self_harm_date self_harm stress_date stress mania_date mania mania_type N_man_b4 depression_date depression N_dep_b4 symptom_to_exposure symptom_to_diagnosis dominant FH_BPD_date FH_BPD FH_psychosis_date FH_psychosis FH_depression_date FH_depression FH_NOS_date FH_NOS FH_anxiety_date FH_anxiety FH_suicide_date FH_suicide FH_LD_date FH_LD FH_substance_date FH_substance FH_any anxiety_date anxiety PD_date PD sleep_date sleep T2DM_date T2DM BMI_date BMI weight ethnicity_date year_exposure ex_time smoke_date CHD_date CHD relationship relationship_date diastolic BP_date systolic hypertension eGFR_date CKD3 LDL LDL_date hi_LDL HDL HDL_date lo_HDL TSH TSH_date thyroid_blood hypothyroid_date hypothyroid hypothyroid_combined ca ca_date hi_ca lo_ca source first_episode OCD_date OCD psych_FH_date first_date age_first_exposure age_first_diagnosis hyperthyroid_date hyperthyroid smoker cardiac_arrythmia_date cardiac_arrythmia Neurological_disorders_date Neurological_disorders Liver_disease_date Liver_disease HIV_AIDS_date HIV_AIDS Fluid_electrolyte_disorder_date Fluid_electrolyte_disorders Diabetes_uncomplicated_date Diabetes_uncomplicated Diabetes_organ_damage_date Diabetes_organ_damage Deficiency_anaemia_date Deficiency_anaemia Congestive_heart_failure_date Congestive_heart_failure Coagulopathy_date Coagulopathy Chronic_pulmonary_disease_date Chronic_pulmonary_disease Weight_loss_date Weight_loss Valvular_disease_date Valvular_disease RA_date RA Pulmonary_circulation_date Peripheral_vascular_date Peripheral_vascular Peptic_ulcer_date Peptic_ulcer first_AP_date first_MS_date first_li_date first_olan_date ap_b4 ap_duration ms_b4 ms_duration li_b4 olan_b4 SSRI first_SSRI_date last_SSRI_date SSRI_b4 SSRI_during TCA first_TCAs_date last_TCAs_date TCA_b4 TCA_during other_ADs first_other_ADs_date last_other_ADs_date other_AD_b4 other_AD_during any_AD_b4 any_AD_during Pulmonary_circulation ethnicity age_diagnosis age_first_reg age_transfer_out age_death age_adhd age_alcohol age_asthma age_cannabis age_conduct age_dermatitis age_migraine age_other_substance_misuse age_psychosis age_self_harm age_stress age_mania age_depression age_FH_BPD age_FH_psychosis age_FH_depression age_FH_NOS age_FH_anxiety age_FH_suicide age_FH_LD age_FH_substance age_anxiety age_PD age_sleep age_T2DM age_BMI age_ethnicity age_smoke age_CHD age_relationship age_BP age_eGFR age_LDL age_HDL age_TSH age_hypothyroid age_ca age_OCD age_psych_FH age_first age_hyperthyroid age_cardiac_arrythmia age_Neurological_disorders age_Liver_disease age_HIV_AIDS age_Fluid_electrolyte_disorder age_Diabetes_uncomplicated age_Diabetes_organ_damage age_Deficiency_anaemia age_Congestive_heart_failure age_Coagulopathy age_Chronic_pulmonary_disease age_Weight_loss age_Valvular_disease age_RA age_Pulmonary_circulation age_Peripheral_vascular age_Peptic_ulcer age_first_AP age_first_MS age_first_li age_first_olan age_first_SSRI age_last_SSRI age_first_TCAs age_last_TCAs age_first_other_ADs age_last_other_ADs
0 10000327 327 1999-11-01 female 1956-01-01 2009-06-30 NaN NaN 30jun2009 20jul2010 end f/u olanzapine 0 02jul1956 23jul2010 25jun2010 0 NaN 0.0 NaN 0 NaN 0 NaN 0 NaN 0 NaN 0 NaN 0 NaN 0 NaN 0 2009-07-08 1 NaN 0 NaN 0 NaN 0 unclear 0 1984-10-18 1 1 25.683779 15.036277 depression NaN 0 NaN 0 NaN 0 NaN NaN NaN 0 NaN 0 NaN 0 NaN 0 0 1970-01-12 1 NaN NaN NaN 0 NaN 0 2009-06-30 24.060369 healthy weight 2009-06-30 NaN 0.076660 2009-06-30 NaN 0 0 NaN 82.0 2009-06-30 132.0 0 NaN 0 NaN NaN 0 NaN NaN 0.0 NaN NaN NaN NaN 0 NaN NaN NaN 0 0 GOLD depression NaN 0 NaN 1984-10-18 53.979465 43.331963 NaN 0 never smoker NaN 0 NaN 0 NaN 0 NaN 0 NaN 0 NaN 0 NaN 0 NaN 0 NaN 0 NaN 0 NaN 0 NaN 0 NaN 0 NaN 0 NaN NaN 0 NaN 0 NaN 2009-07-08 NaN 2009-07-08 0 NaN 1 0.963723 0 1 Citalopram hydrobromide 2009-07-09 2010-06-25 0 0 NaN NaN NaN 0 0 NaN NaN NaN 0 0 0.0 0.0 0 White 43.863014 53.531507 NaN NaN NaN NaN NaN NaN NaN NaN NaN NaN 53.553425 NaN NaN NaN 28.816438 NaN NaN NaN NaN NaN NaN NaN NaN 14.041096 NaN NaN NaN 53.531507 53.531507 53.531507 NaN NaN 53.531507 NaN NaN NaN NaN NaN NaN NaN NaN 28.816438 NaN NaN NaN NaN NaN NaN NaN NaN NaN NaN NaN NaN NaN NaN NaN NaN NaN NaN NaN 53.553425 NaN 53.553425 53.556164 54.517808 NaN NaN NaN NaN
1 1000091720274 20274 2004-03-23 female 1923-01-01 1987-05-08 2015-01-22 2014-12-27 08may1987 22jan2015 died lithium 1 03jul1923 01feb1993 08dec1992 1 0.0 0.0 NaN 0 NaN 0 NaN 0 NaN 0 NaN 0 NaN 0 NaN 0 NaN 0 1978-01-01 1 NaN 0 NaN 0 NaN 0 unclear 0 NaN 0 0 0.000000 11.288158 unclear NaN 0 NaN 0 NaN 0 NaN 0.0 NaN 0 NaN 0 NaN 0 NaN 0 0 2006-09-11 0 NaN 0.0 NaN 0 2001-06-12 0 NaN NaN healthy weight 2007-01-25 1992.0 0.150582 1993-10-13 NaN 0 0 NaN NaN 2005-10-05 NaN 0 2011-07-26 0 NaN 2005-05-20 0 NaN 2006-05-19 0.0 NaN 2001-10-09 normal NaN 0 0.0 2.08 2011-01-07 0 1 AURUM mania NaN 0 NaN 1992-12-08 69.434631 80.722794 NaN 0 ex smoker 2008-03-27 0 2008-05-14 0 NaN 0 NaN 0 NaN 0 2001-06-12 0 2011-07-05 0 2005-03-22 0 NaN 0 NaN 0 NaN 0 NaN 0 NaN 0 2001-02-26 0 NaN NaN 0 1964-01-01 1 1992-12-08 2009-06-17 1992-12-08 NaN 0 0.073922 0 5.475702 0 0 Sertraline hydrochloride 1996-07-11 2014-12-08 0 0 NaN NaN NaN 0 0 Mirtazapine 2006-11-01 2006-11-15 0 0 NaN NaN 0 White 81.279452 64.391781 92.120548 92.049315 NaN NaN NaN NaN NaN NaN NaN NaN 55.038356 NaN NaN NaN NaN NaN NaN NaN NaN NaN NaN NaN NaN 83.750685 NaN NaN 78.49863 NaN 84.123288 70.830137 NaN NaN 82.816438 88.624658 82.438356 83.435616 78.824658 NaN 88.076712 NaN NaN 69.983562 NaN 85.293151 85.424658 NaN NaN NaN 78.498630 88.567123 82.276712 NaN NaN NaN NaN NaN 78.208219 NaN NaN 41.027397 69.983562 86.517808 69.983562 NaN 73.575342 91.997260 NaN NaN 83.890411 83.928767
2 1000107720274 20274 2014-04-03 female 1972-01-01 1999-01-15 NaN NaN 15jan1999 28aug2018 end f/u lithium 1 02jul1972 18mar2016 17feb2016 1 0.0 0.0 NaN 0 NaN 0 NaN 0 NaN 0 NaN 0 NaN 0 NaN 0 NaN 0 NaN 0 1992-01-01 1 NaN 0 NaN 0 unclear 0 1992-01-01 1 1 24.128679 22.253252 depression NaN 0 NaN 0 NaN 0 NaN 0.0 NaN 0 NaN 0 NaN 0 NaN 0 0 2008-02-20 1 NaN 0.0 2004-08-09 1 NaN 0 2015-07-09 24.677023 healthy weight NaN 2016.0 0.082136 2006-02-08 NaN 0 0 NaN 83.0 2015-07-09 125.0 0 2016-01-12 0 3.70 2016-01-12 1 1.40 2015-07-16 0.0 3.17 2016-01-12 normal NaN 0 0.0 2.33 2016-01-12 0 0 AURUM depression NaN 0 NaN 1992-01-01 43.627651 41.752224 NaN 0 ex smoker NaN 0 NaN 0 NaN 0 NaN 0 NaN 0 NaN 0 NaN 0 NaN 0 NaN 0 NaN 0 NaN 0 NaN 0 NaN 0 NaN 0 NaN NaN 0 NaN 0 2013-11-11 2014-04-17 2016-02-17 NaN 1 3.671458 1 1.774127 0 0 Citalopram hydrobromide 2000-08-21 2009-06-30 1 0 Amitriptyline hydrochloride 2004-08-09 2004-08-09 1 0 Mirtazapine 2004-12-03 2018-08-02 0 1 NaN NaN 0 White 42.282192 27.057534 NaN NaN NaN NaN NaN NaN NaN NaN NaN NaN NaN 20.013699 NaN NaN 20.013699 NaN NaN NaN NaN NaN NaN NaN NaN 36.161644 NaN 32.627397 NaN 43.547945 NaN 34.128767 NaN NaN 43.547945 44.060274 44.060274 43.567123 44.060274 NaN 44.060274 NaN NaN 20.013699 NaN NaN NaN NaN NaN NaN NaN NaN NaN NaN NaN NaN NaN NaN NaN NaN NaN NaN 41.890411 42.320548 44.158904 NaN 28.657534 37.520548 32.627397 32.627397 32.945205 46.616438
3 1000157020274 20274 1993-01-13 female 1936-01-01 1986-10-23 NaN NaN 01jan1987 28aug2018 end f/u lithium 1 02jul1936 15jan1996 13jan1993 1 1.0 1.0 NaN 0 NaN 0 NaN 0 NaN 0 NaN 0 NaN 0 NaN 0 NaN 0 NaN 0 NaN 0 NaN 0 1993-01-13 1 mania 1 NaN 0 0 0.000000 0.000000 mania NaN 0 NaN 0 NaN 0 NaN 0.0 NaN 0 NaN 0 NaN 0 NaN 0 0 NaN 0 NaN 0.0 NaN 0 NaN 0 1993-01-20 25.260000 overweight 2006-10-11 1993.0 3.003422 1993-01-20 1985-01-01 1 0 NaN NaN 2018-04-13 NaN 0 2006-10-11 0 NaN 2010-07-22 0 NaN 2018-04-13 0.0 NaN 2003-03-18 normal NaN 0 0.0 2.59 2014-12-08 0 0 AURUM mania NaN 0 NaN 1993-01-13 56.533882 56.533882 NaN 0 current smoker NaN 0 NaN 0 NaN 0 NaN 0 2018-05-18 0 2013-09-11 0 NaN 0 NaN 0 NaN 0 NaN 0 NaN 0 NaN 0 NaN 0 NaN 0 NaN NaN 0 NaN 0 NaN NaN 1993-01-13 NaN 0 NaN 0 NaN 0 0 NaN NaN NaN 0 0 NaN NaN NaN 0 0 Mirtazapine 2013-01-02 2018-08-15 0 0 NaN NaN 0 White 57.073973 50.843836 NaN NaN NaN NaN NaN NaN NaN NaN NaN NaN NaN NaN NaN 57.073973 NaN NaN NaN NaN NaN NaN NaN NaN NaN NaN NaN NaN NaN 57.093151 70.824658 57.093151 49.035616 NaN 82.336986 70.824658 74.605479 82.336986 67.254795 NaN 78.989041 NaN NaN 57.073973 NaN NaN NaN NaN NaN 82.432877 77.747945 NaN NaN NaN NaN NaN NaN NaN NaN NaN NaN NaN NaN NaN 57.073973 NaN NaN NaN NaN NaN 77.057534 82.676712
4 1000172020274 20274 2008-09-16 female 1987-01-01 2009-07-27 NaN NaN 27jul2009 28aug2018 end f/u olanzapine 1 03jul1987 27mar2014 10jan2014 1 0.0 0.0 NaN 0 NaN 0 1996-07-01 1 NaN 0 NaN 0 NaN 0 NaN 0 NaN 0 NaN 0 2008-07-22 1 NaN 0 2014-01-10 1 mania 1 2003-01-20 1 12 10.973306 5.656400 depression NaN 0 NaN 0 NaN 0 NaN 0.0 NaN 0 NaN 0 NaN 0 NaN 0 0 NaN 0 2014-06-25 0.0 2008-07-24 1 NaN 0 2013-11-19 23.011179 healthy weight 2008-10-31 2014.0 0.208077 1996-02-20 NaN 0 0 NaN 75.0 2013-12-11 127.0 0 2010-06-17 0 1.98 2012-05-21 0 1.38 2012-05-21 0.0 2.83 2013-10-07 normal NaN 0 0.0 2.18 2011-01-14 0 0 AURUM depression NaN 0 NaN 2003-01-20 26.524298 21.207392 NaN 0 ex smoker NaN 0 NaN 0 NaN 0 NaN 0 NaN 0 2014-09-25 0 NaN 0 NaN 0 NaN 0 NaN 0 1996-07-01 1 NaN 0 NaN 0 NaN 0 NaN NaN 0 NaN 0 2008-09-02 NaN NaN 2014-01-10 1 0.213552 0 NaN 0 0 Citalopram hydrobromide 2003-01-20 2013-06-30 1 0 NaN NaN NaN 0 0 NaN NaN NaN 0 0 NaN NaN 0 Other 21.723288 22.583562 NaN NaN NaN NaN 9.50411 NaN NaN NaN NaN NaN NaN 21.569863 NaN 27.043836 16.063014 NaN NaN NaN NaN NaN NaN NaN NaN NaN 27.49863 21.575342 NaN 26.901370 21.846575 9.142466 NaN NaN 26.961644 23.473973 25.402740 25.402740 26.783562 NaN 24.052055 NaN NaN 16.063014 NaN NaN NaN NaN NaN NaN 27.750685 NaN NaN NaN NaN 9.50411 NaN NaN NaN NaN NaN NaN 21.684932 NaN NaN 27.043836 16.063014 26.512329 NaN NaN NaN NaN

NB: Because lo_HDL has missing values, pandas interprets the whole column as float by default.
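
This is standard pandas behaviour: an integer column that contains NaN is upcast to float64. A minimal illustration (not part of the original notebook; the nullable Int64 dtype is one way to keep integer semantics):

s = pd.Series([0, 1, None])
print(s.dtype)                  # float64 -- the NaN forces the upcast
print(s.astype('Int64').dtype)  # Int64 -- pandas' nullable integer dtype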

Characterisation

In [2]:
# What columns do we have?
print('\n'.join(sorted(df.columns.to_list())))
BMI
BMI_date
BP_date
CHD
CHD_date
CKD3
Chronic_pulmonary_disease
Chronic_pulmonary_disease_date
Coagulopathy
Coagulopathy_date
Congestive_heart_failure
Congestive_heart_failure_date
Deficiency_anaemia
Deficiency_anaemia_date
Diabetes_organ_damage
Diabetes_organ_damage_date
Diabetes_uncomplicated
Diabetes_uncomplicated_date
FH_BPD
FH_BPD_date
FH_LD
FH_LD_date
FH_NOS
FH_NOS_date
FH_anxiety
FH_anxiety_date
FH_any
FH_depression
FH_depression_date
FH_psychosis
FH_psychosis_date
FH_substance
FH_substance_date
FH_suicide
FH_suicide_date
Fluid_electrolyte_disorder_date
Fluid_electrolyte_disorders
HDL
HDL_date
HIV_AIDS
HIV_AIDS_date
LDL
LDL_date
Liver_disease
Liver_disease_date
N_dep_b4
N_man_b4
Neurological_disorders
Neurological_disorders_date
OCD
OCD_date
PD
PD_date
Peptic_ulcer
Peptic_ulcer_date
Peripheral_vascular
Peripheral_vascular_date
Pulmonary_circulation
Pulmonary_circulation_date
RA
RA_date
SSRI
SSRI_b4
SSRI_during
T2DM
T2DM_date
TCA
TCA_b4
TCA_during
TSH
TSH_date
Valvular_disease
Valvular_disease_date
Weight_loss
Weight_loss_date
adhd
adhd_date
age_BMI
age_BP
age_CHD
age_Chronic_pulmonary_disease
age_Coagulopathy
age_Congestive_heart_failure
age_Deficiency_anaemia
age_Diabetes_organ_damage
age_Diabetes_uncomplicated
age_FH_BPD
age_FH_LD
age_FH_NOS
age_FH_anxiety
age_FH_depression
age_FH_psychosis
age_FH_substance
age_FH_suicide
age_Fluid_electrolyte_disorder
age_HDL
age_HIV_AIDS
age_LDL
age_Liver_disease
age_Neurological_disorders
age_OCD
age_PD
age_Peptic_ulcer
age_Peripheral_vascular
age_Pulmonary_circulation
age_RA
age_T2DM
age_TSH
age_Valvular_disease
age_Weight_loss
age_adhd
age_alcohol
age_anxiety
age_asthma
age_ca
age_cannabis
age_cardiac_arrythmia
age_conduct
age_death
age_depression
age_dermatitis
age_diagnosis
age_eGFR
age_ethnicity
age_first
age_first_AP
age_first_MS
age_first_SSRI
age_first_TCAs
age_first_diagnosis
age_first_exposure
age_first_li
age_first_olan
age_first_other_ADs
age_first_reg
age_hyperthyroid
age_hypothyroid
age_last_SSRI
age_last_TCAs
age_last_other_ADs
age_mania
age_migraine
age_other_substance_misuse
age_psych_FH
age_psychosis
age_relationship
age_self_harm
age_sleep
age_smoke
age_stress
age_transfer_out
alcohol
alcohol_date
anxiety
anxiety_date
any_AD_b4
any_AD_during
ap_b4
ap_duration
asthma
asthma_date
ca
ca_date
cannabis
cannabis_date
cardiac_arrythmia
cardiac_arrythmia_date
cohort_end
cohort_start
conduct
conduct_date
death_date
depression
depression_date
dermatitis
dermatitis_date
diagnosis_date
diastolic
dob
dominant
eGFR_date
end_reason
ethnicity
ethnicity_date
ex_time
exposure
exposure_end
exposure_start
first_AP_date
first_MS_date
first_SSRI_date
first_TCAs_date
first_date
first_episode
first_li_date
first_olan_date
first_other_ADs_date
first_reg_date
hi_LDL
hi_ca
hypertension
hyperthyroid
hyperthyroid_date
hypothyroid
hypothyroid_combined
hypothyroid_date
incident_script
last_SSRI_date
last_TCAs_date
last_other_ADs_date
li_b4
lo_HDL
lo_ca
mania
mania_date
mania_type
migraine
migraine_date
ms_b4
ms_duration
olan_b4
other_AD_b4
other_AD_during
other_ADs
other_substance_misuse
other_substance_misuse_date
patid
pracid
psych_FH_date
psychosis
psychosis_date
relationship
relationship_date
responder2
response2_1
self_harm
self_harm_date
sex
sleep
sleep_date
smoke_date
smoker
source
stress
stress_date
suitable
symptom_to_diagnosis
symptom_to_exposure
systolic
thyroid_blood
transfer_out_date
weight
year_exposure
yob
In [3]:
# How many unique patients?
print("Total patients:", df.patid.nunique())
print(df['suitable'].value_counts(dropna=False))
Total patients: 38957
1    31518
0     7439
Name: suitable, dtype: int64

Now we keep only the patients suitable for inclusion in the analysis (suitable==1), defined as patients with more than 2 years of follow-up after exposure_start.

In [4]:
df_old = df.copy()
df = df.loc[df.suitable==1]
print("New total patients:", df.patid.nunique())
New total patients: 31518
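As a sanity check, this flag can in principle be recomputed from the raw dates. A minimal sketch, assuming suitable was derived from exposure_start and cohort_end (stored as strings like '30jun2009'):

# Hedged sketch: recompute the >2-year follow-up criterion from the raw date strings
starts = pd.to_datetime(df_old['exposure_start'], format='%d%b%Y')
ends = pd.to_datetime(df_old['cohort_end'], format='%d%b%Y')
follow_up_years = (ends - starts).dt.days / 365.25
# Fraction of patients where the recomputed flag agrees with the stored one
print(((follow_up_years > 2).astype(int) == df_old['suitable']).mean())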
In [5]:
# Count values for a few important columns
for item in ['source', 'exposure', 'suitable', 'response2_1', 'responder2', 'symptom_to_exposure', 'exposure_end']:
    print(item)
    print(df[item].value_counts(dropna=False))
    print(df[item].describe())
    print()
source
AURUM    20910
GOLD     10608
Name: source, dtype: int64
count     31518
unique        2
top       AURUM
freq      20910
Name: source, dtype: object

exposure
lithium       19106
olanzapine    12412
Name: exposure, dtype: int64
count       31518
unique          2
top       lithium
freq        19106
Name: exposure, dtype: object

suitable
1    31518
Name: suitable, dtype: int64
count    31518.0
mean         1.0
std          0.0
min          1.0
25%          1.0
50%          1.0
75%          1.0
max          1.0
Name: suitable, dtype: float64

response2_1
0.0    14785
1.0    11848
NaN     4885
Name: response2_1, dtype: int64
count    26633.000000
mean         0.444862
std          0.496960
min          0.000000
25%          0.000000
50%          0.000000
75%          1.000000
max          1.000000
Name: response2_1, dtype: float64

responder2
0.0    19670
1.0    11848
Name: responder2, dtype: int64
count    31518.000000
mean         0.375912
std          0.484365
min          0.000000
25%          0.000000
50%          0.000000
75%          1.000000
max          1.000000
Name: responder2, dtype: float64

symptom_to_exposure
 0.000000     3251
 0.038330       46
 0.002738       45
 0.019165       42
 0.093087       38
              ... 
 4.098563        1
 23.208761       1
 20.591375       1
-15.641341       1
 14.631075       1
Name: symptom_to_exposure, Length: 11420, dtype: int64
count    31518.000000
mean        10.186286
std         12.650202
min        -30.031485
25%          0.813142
50%          6.557153
75%         15.629706
max        148.963730
Name: symptom_to_exposure, dtype: float64

exposure_end
04mar2019    71
01mar2019    65
26feb2019    64
19feb2019    59
18feb2019    58
             ..
30mar2013     1
19jun1989     1
24sep2007     1
24jan2014     1
09dec1999     1
Name: exposure_end, Length: 7796, dtype: int64
count         31518
unique         7796
top       04mar2019
freq             71
Name: exposure_end, dtype: object

Table 1

In [31]:
features13 = ['age_first_exposure', 'age_first_diagnosis', 'symptom_to_exposure', 'psychosis', 'depression', 'mania', 'dominant', 'sex', 'FH_BPD', 'FH_depression', 'FH_psychosis', 'weight', 'self_harm']
In [34]:
features = """Total, N
Age at diagnosis, median (IQR)
Age at medication initiation, median (IQR)
Years between diagnosis and exposure, median (IQR)
Female, n (%)
First presentation mania, n (%)
First presentation depression, n (%)
Depression dominant, n (%)
Psychotic experiences, n (%)
Self-harm history, n (%)
Smoker, n (%)
Family history for bipolar disorder, n (%)
Family history for depression, n (%)
Family history for psychosis, n (%)
Overweight or obese, n (%)"""
feature_list = features.split('\n')
Table1 = pd.DataFrame(feature_list, columns=['Features'])


for exp, exp_df in df.groupby('exposure'):
    data_list = [str(len(exp_df))]
    for feature in ['age_first_diagnosis', 'age_first_exposure', 'symptom_to_exposure']:
        median = exp_df[feature].median()
        Q1 = np.quantile(exp_df[feature], 0.25)
        Q3 = np.quantile(exp_df[feature], 0.75)
        IQR = Q3 - Q1
        data_list.append("{:.2f} ({:.2f})".format(median, IQR))
    for feature in ['sex', 'mania', 'depression', 'dominant', 'psychosis', 'self_harm', 'smoker', 'FH_BPD', 'FH_depression', 'FH_psychosis', 'weight']:
        # Some features require counting specific category values
        if feature == 'sex':
            count = (exp_df[feature] == 'female').sum()
        elif feature == 'dominant':
            count = (exp_df[feature] == 'depression').sum()
        elif feature == 'smoker':
            count = exp_df[feature].isin(['ex smoker', 'current smoker']).sum()
        elif feature == 'weight':
            count = exp_df[feature].isin(['overweight', 'obese']).sum()
        else:
            count = exp_df[feature].sum()
        data_list.append("{} ({:.2f} %)".format(count, count / len(exp_df) * 100))
    Table1[exp] = data_list

display(Table1)
Features lithium olanzapine
0 Total, N 19106 12412
1 Age at diagnosis, median (IQR) 40.82 (24.13) 39.08 (22.36)
2 Age at medication initiation, median (IQR) 46.51 (23.20) 42.47 (23.04)
3 Years between diagnosis and exposure, median (... 7.37 (15.60) 5.45 (13.35)
4 Female, n (%) 11526 (60.33 %) 6858 (55.25 %)
5 First presentation mania, n (%) 4705 (24.63 %) 3498 (28.18 %)
6 First presentation depression, n (%) 11233 (58.79 %) 7453 (60.05 %)
7 Depression dominant, n (%) 9662 (50.57 %) 6457 (52.02 %)
8 Psychotic experiences, n (%) 4939 (25.85 %) 3806 (30.66 %)
9 Self-harm history, n (%) 2362 (12.36 %) 1822 (14.68 %)
10 Smoker, n (%) 13970 (73.12 %) 8796 (70.87 %)
11 Family history for bipolar disorder, n (%) 402 (2.10 %) 136 (1.10 %)
12 Family history for depression, n (%) 375 (1.96 %) 245 (1.97 %)
13 Family history for psychosis, n (%) 111 (0.58 %) 141 (1.14 %)
14 Overweight or obese, n (%) 6540 (34.23 %) 5038 (40.59 %)

Modification

Some values are negative, which should be impossible for these variables.

In [8]:
display(df.loc[df.symptom_to_exposure < 0, 'symptom_to_exposure'])
display(df.loc[df.age_first_diagnosis < 0, ['age_first_diagnosis', 'symptom_to_exposure']])
pd.set_option('display.max_rows', 100)
display(df.loc[(df.symptom_to_exposure>65) &  (df.symptom_to_exposure<80)])
74      -8.372348
79      -4.016427
103     -2.948665
158     -7.860370
176     -0.032854
           ...   
38872   -5.037645
38877   -0.016427
38880   -5.653662
38924   -0.977413
38937   -3.529090
Name: symptom_to_exposure, Length: 1213, dtype: float64
age_first_diagnosis symptom_to_exposure
2067 -77.500343 106.620120
6260 -68.498291 104.038330
6771 -0.134155 24.533880
7937 -87.498970 134.294310
8706 -98.499657 143.857640
8709 -109.500340 141.295000
9641 -46.499657 89.848053
14944 -58.502396 102.907600
15065 -49.503078 104.700890
15412 -90.499657 145.867220
15525 -117.500340 148.963730
15585 -99.498970 134.179340
15849 -56.498287 99.531830
15997 -73.500343 139.493500
16216 -0.314853 23.898699
17417 -96.498291 139.028060
23681 -0.114990 57.355236
24757 -22.499659 97.952087
24810 -32.498287 97.932922
24849 -35.498974 97.092400
25084 -0.432580 31.414101
25134 -30.499659 98.453117
31628 -50.502396 102.009580
31843 -29.500341 95.737167
34554 -60.498287 136.251880
37134 -0.156057 60.238194
patid pracid diagnosis_date sex yob first_reg_date transfer_out_date death_date cohort_start cohort_end end_reason exposure incident_script dob exposure_end exposure_start suitable responder2 response2_1 adhd_date adhd alcohol_date alcohol asthma_date asthma cannabis_date cannabis conduct_date conduct dermatitis_date dermatitis migraine_date migraine other_substance_misuse_date other_substance_misuse psychosis_date psychosis self_harm_date self_harm stress_date stress mania_date mania mania_type N_man_b4 depression_date depression N_dep_b4 symptom_to_exposure symptom_to_diagnosis dominant FH_BPD_date FH_BPD FH_psychosis_date FH_psychosis FH_depression_date FH_depression FH_NOS_date FH_NOS FH_anxiety_date FH_anxiety FH_suicide_date FH_suicide FH_LD_date FH_LD FH_substance_date FH_substance FH_any anxiety_date anxiety PD_date PD sleep_date sleep T2DM_date T2DM BMI_date BMI weight ethnicity_date year_exposure ex_time smoke_date CHD_date CHD relationship relationship_date diastolic BP_date systolic hypertension eGFR_date CKD3 LDL LDL_date hi_LDL HDL HDL_date lo_HDL TSH TSH_date thyroid_blood hypothyroid_date hypothyroid hypothyroid_combined ca ca_date hi_ca lo_ca source first_episode OCD_date OCD psych_FH_date first_date age_first_exposure age_first_diagnosis hyperthyroid_date hyperthyroid smoker cardiac_arrythmia_date cardiac_arrythmia Neurological_disorders_date Neurological_disorders Liver_disease_date Liver_disease HIV_AIDS_date HIV_AIDS Fluid_electrolyte_disorder_date Fluid_electrolyte_disorders Diabetes_uncomplicated_date Diabetes_uncomplicated Diabetes_organ_damage_date Diabetes_organ_damage Deficiency_anaemia_date Deficiency_anaemia Congestive_heart_failure_date Congestive_heart_failure Coagulopathy_date Coagulopathy Chronic_pulmonary_disease_date Chronic_pulmonary_disease Weight_loss_date Weight_loss Valvular_disease_date Valvular_disease RA_date RA Pulmonary_circulation_date Peripheral_vascular_date Peripheral_vascular Peptic_ulcer_date Peptic_ulcer first_AP_date first_MS_date first_li_date first_olan_date ap_b4 ap_duration ms_b4 ms_duration li_b4 olan_b4 SSRI first_SSRI_date last_SSRI_date SSRI_b4 SSRI_during TCA first_TCAs_date last_TCAs_date TCA_b4 TCA_during other_ADs first_other_ADs_date last_other_ADs_date other_AD_b4 other_AD_during any_AD_b4 any_AD_during Pulmonary_circulation ethnicity age_diagnosis age_first_reg age_transfer_out age_death age_adhd age_alcohol age_asthma age_cannabis age_conduct age_dermatitis age_migraine age_other_substance_misuse age_psychosis age_self_harm age_stress age_mania age_depression age_FH_BPD age_FH_psychosis age_FH_depression age_FH_NOS age_FH_anxiety age_FH_suicide age_FH_LD age_FH_substance age_anxiety age_PD age_sleep age_T2DM age_BMI age_ethnicity age_smoke age_CHD age_relationship age_BP age_eGFR age_LDL age_HDL age_TSH age_hypothyroid age_ca age_OCD age_psych_FH age_first age_hyperthyroid age_cardiac_arrythmia age_Neurological_disorders age_Liver_disease age_HIV_AIDS age_Fluid_electrolyte_disorder age_Diabetes_uncomplicated age_Diabetes_organ_damage age_Deficiency_anaemia age_Congestive_heart_failure age_Coagulopathy age_Chronic_pulmonary_disease age_Weight_loss age_Valvular_disease age_RA age_Pulmonary_circulation age_Peripheral_vascular age_Peptic_ulcer age_first_AP age_first_MS age_first_li age_first_olan age_first_SSRI age_last_SSRI age_first_TCAs age_last_TCAs age_first_other_ADs age_last_other_ADs
5883 13787565 565 2001-11-14 male 1933-01-01 2001-02-05 NaN NaN 05feb2001 19nov2014 end f/u lithium 0 03jul1933 16sep2003 30aug2002 1 0.0 NaN NaN 0 NaN 0 NaN 0 NaN 0 NaN 0 2004-09-24 0 NaN 0 NaN 0 2001-11-14 1 1993-10-15 1 NaN 0 2001-11-14 1 mania+psychoses 1 1933-01-28 1 10 69.585213 68.793976 depression NaN 0 NaN 0 NaN 0 NaN NaN NaN 0 NaN 0 NaN 0 NaN 0 0 1933-01-28 1 1983-04-06 1.0 2001-09-22 1 NaN 0 2001-02-09 22.737589 healthy weight NaN NaN 1.045859 2003-10-22 2008-11-20 0 0 NaN 82.0 2002-05-22 150.0 1 2010-04-16 0 NaN NaN 0 NaN NaN 0.0 0.57 2001-11-05 NaN NaN 0 NaN NaN NaN 0 0 GOLD depression NaN 0 NaN 1933-01-28 69.158112 68.366875 NaN 0 never smoker NaN 0 2006-02-27 0 NaN 0 NaN 0 NaN 0 NaN 0 NaN 0 NaN 0 NaN 0 NaN 0 NaN 0 NaN 0 NaN 0 NaN 0 NaN NaN 0 NaN 0 2001-09-28 NaN 2001-11-16 2003-09-16 1 0.919918 0 NaN 1 0 Fluoxetine hydrochloride 2004-10-27 2007-08-06 0 0 NaN NaN NaN 0 0 Mirtazapine 2001-02-16 2004-10-21 0 1 0.0 1.0 0 White 68.915068 68.142466 NaN NaN NaN NaN NaN NaN NaN 71.778082 NaN NaN 68.915068 60.827397 NaN 68.915068 0.073973 NaN NaN NaN NaN NaN NaN NaN NaN 0.073973 50.293151 68.769863 NaN 68.153425 NaN 70.852055 75.936986 NaN 69.432877 77.339726 NaN NaN 68.890411 NaN NaN NaN NaN 0.073973 NaN NaN 73.205479 NaN NaN NaN NaN NaN NaN NaN NaN NaN NaN NaN NaN NaN NaN NaN 68.786301 NaN 68.920548 70.753425 71.868493 74.643836 NaN NaN 68.172603 71.852055
25718 2842029 29 1965-01-01 female 1911-01-01 2002-01-29 2004-03-16 2004-02-14 29jan2002 16mar2004 died lithium 0 03jul1911 09mar2004 14feb2002 1 1.0 1.0 NaN 0 NaN 0 NaN 0 NaN 0 NaN 0 NaN 0 NaN 0 NaN 0 1965-01-01 1 NaN 0 NaN 0 NaN 0 unclear 0 1930-01-01 1 2 72.120468 35.000683 depression NaN 0 NaN 0 NaN 0 NaN NaN NaN 0 NaN 0 NaN 0 NaN 0 0 NaN 0 NaN NaN NaN 0 NaN 0 NaN NaN healthy weight NaN NaN 2.064339 NaN NaN 0 0 NaN 80.0 2003-04-19 128.0 0 NaN 0 NaN NaN 0 NaN NaN 0.0 NaN NaN NaN NaN 0 NaN NaN NaN 0 0 GOLD depression NaN 0 NaN 1930-01-01 90.620125 53.500343 NaN 0 never smoker NaN 0 1998-01-01 1 NaN 0 NaN 0 NaN 0 NaN 0 NaN 0 NaN 0 NaN 0 NaN 0 NaN 0 NaN 0 NaN 0 NaN 0 NaN NaN 0 NaN 0 NaN NaN 2002-02-14 NaN 0 NaN 0 NaN 0 0 Sertraline hydrochloride 2002-02-14 2003-02-06 0 1 NaN NaN NaN 0 0 Mirtazapine 2003-08-19 2003-11-28 0 0 0.0 1.0 0 White 54.038356 91.139726 93.268493 93.183562 NaN NaN NaN NaN NaN NaN NaN NaN 54.038356 NaN NaN NaN 19.013699 NaN NaN NaN NaN NaN NaN NaN NaN NaN NaN NaN NaN NaN NaN NaN NaN NaN 92.358904 NaN NaN NaN NaN NaN NaN NaN NaN 19.013699 NaN NaN 87.060274 NaN NaN NaN NaN NaN NaN NaN NaN NaN NaN NaN NaN NaN NaN NaN NaN NaN 91.183562 NaN 91.183562 92.161644 NaN NaN 92.693151 92.969863
30300 4704742 742 1928-01-01 female 1917-01-01 1980-10-29 2006-12-01 2006-12-01 01jan1987 01dec2006 died lithium 1 03jul1917 05aug1997 23oct1996 1 0.0 0.0 NaN 0 NaN 0 NaN 0 NaN 0 NaN 0 NaN 0 NaN 0 NaN 0 1976-06-01 1 1981-06-06 1 NaN 0 1976-06-01 1 mania+psychoses 1 NaN 0 0 68.810402 0.000000 mania NaN 0 NaN 0 NaN 0 NaN NaN NaN 0 NaN 0 NaN 0 NaN 0 0 NaN 0 NaN NaN 1993-06-01 1 NaN 0 NaN NaN healthy weight NaN NaN 0.783025 2004-11-11 2002-12-10 0 1 NaN 80.0 1999-12-07 140.0 0 2003-08-21 0 NaN NaN 0 NaN NaN 0.0 NaN NaN NaN 1974-01-01 1 NaN NaN NaN 0 0 GOLD mania NaN 0 NaN 1928-01-01 79.307327 10.496920 NaN 0 ex smoker NaN 0 NaN 0 NaN 0 NaN 0 NaN 0 NaN 0 NaN 0 NaN 0 2002-12-10 0 NaN 0 NaN 0 NaN 0 1997-09-01 0 NaN 0 NaN NaN 0 NaN 0 NaN NaN 1996-10-23 NaN 0 NaN 0 NaN 0 0 Sertraline hydrochloride 1997-11-03 2001-12-10 0 0 Clomipramine hydrochloride 2002-05-24 2002-07-15 0 0 Venlafaxine hydrochloride 1996-10-15 2003-06-27 0 1 0.0 1.0 0 White 11.005479 63.868493 89.975342 89.975342 NaN NaN NaN NaN NaN NaN NaN NaN 59.454795 64.471233 NaN 59.454795 NaN NaN NaN NaN NaN NaN NaN NaN NaN NaN NaN 76.465753 NaN NaN NaN 87.920548 85.997260 NaN 82.986301 86.693151 NaN NaN NaN 57.038356 NaN NaN NaN 11.005479 NaN NaN NaN NaN NaN NaN NaN NaN NaN 85.99726 NaN NaN NaN 80.720548 NaN NaN NaN NaN NaN NaN 79.863014 NaN 80.893151 84.997260 85.449315 85.591781 79.841096 86.542466
32956 6177329 329 1936-11-24 male 1936-01-01 1991-01-28 2009-03-12 NaN 28jan1991 12mar2009 started AP lithium 1 02jul1936 14may2004 16apr2004 1 0.0 0.0 NaN 0 NaN 0 NaN 0 NaN 0 NaN 0 NaN 0 NaN 0 NaN 0 1936-11-24 1 NaN 0 NaN 0 NaN 0 unclear 0 2006-01-23 0 0 67.392197 0.000000 unclear NaN 0 NaN 0 NaN 0 NaN NaN NaN 0 NaN 0 NaN 0 NaN 0 0 NaN 0 NaN NaN NaN 0 NaN 0 2002-02-28 31.134583 obese NaN NaN 0.076660 2004-03-12 NaN 0 0 NaN 90.0 2004-03-12 156.0 0 2009-01-15 0 NaN NaN 0 NaN NaN 0.0 NaN NaN NaN NaN 0 NaN NaN NaN 0 0 GOLD mania NaN 0 NaN 1936-11-24 67.789185 0.396988 NaN 0 never smoker NaN 0 NaN 0 NaN 0 NaN 0 NaN 0 NaN 0 NaN 0 NaN 0 NaN 0 NaN 0 NaN 0 NaN 0 NaN 0 NaN 0 NaN NaN 0 NaN 0 2004-05-14 NaN 2004-04-16 NaN 0 0.284736 0 NaN 0 0 NaN NaN NaN 0 0 NaN NaN NaN 0 0 NaN NaN NaN 0 0 0.0 0.0 0 White 0.898630 55.112329 73.243836 NaN NaN NaN NaN NaN NaN NaN NaN NaN 0.898630 NaN NaN NaN 70.109589 NaN NaN NaN NaN NaN NaN NaN NaN NaN NaN NaN NaN 66.205479 NaN 68.241096 NaN NaN 68.241096 73.090411 NaN NaN NaN NaN NaN NaN NaN 0.898630 NaN NaN NaN NaN NaN NaN NaN NaN NaN NaN NaN NaN NaN NaN NaN NaN NaN NaN 68.413699 NaN 68.336986 NaN NaN NaN NaN NaN NaN NaN
36808 8618764 764 1931-07-14 female 1931-01-01 1984-08-01 2003-11-21 2003-11-21 01jan1987 21nov2003 started AP lithium 1 03jul1931 23aug1999 07apr1999 1 0.0 0.0 NaN 0 NaN 0 NaN 0 NaN 0 NaN 0 NaN 0 NaN 0 NaN 0 NaN 0 NaN 0 NaN 0 1931-07-14 1 mania 1 NaN 0 0 67.731689 0.000000 mania NaN 0 NaN 0 NaN 0 NaN NaN NaN 0 NaN 0 NaN 0 NaN 0 0 NaN 0 NaN NaN NaN 0 NaN 0 NaN NaN healthy weight NaN NaN 0.377823 1996-09-09 NaN 0 0 NaN 70.0 1995-11-07 120.0 0 NaN 0 NaN NaN 0 NaN NaN 0.0 NaN NaN NaN 1931-07-14 1 NaN NaN NaN 0 0 GOLD mania NaN 0 NaN 1931-07-14 67.761810 0.030116 NaN 0 current smoker NaN 0 NaN 0 NaN 0 NaN 0 NaN 0 NaN 0 NaN 0 NaN 0 NaN 0 NaN 0 NaN 0 NaN 0 NaN 0 NaN 0 NaN NaN 0 NaN 0 1999-08-23 NaN 1999-04-07 NaN 0 4.219028 0 NaN 0 0 NaN NaN NaN 0 0 NaN NaN NaN 0 0 NaN NaN NaN 0 0 0.0 0.0 0 White 0.531507 53.619178 72.936986 72.936986 NaN NaN NaN NaN NaN NaN NaN NaN NaN NaN NaN 0.531507 NaN NaN NaN NaN NaN NaN NaN NaN NaN NaN NaN NaN NaN NaN NaN 65.734247 NaN NaN 64.893151 NaN NaN NaN NaN 0.531507 NaN NaN NaN 0.531507 NaN NaN NaN NaN NaN NaN NaN NaN NaN NaN NaN NaN NaN NaN NaN NaN NaN NaN 68.687671 NaN 68.309589 NaN NaN NaN NaN NaN NaN NaN

We fix these invalid values: negative symptom-to-exposure times are clipped to 0, negative ages at first diagnosis are set to missing, and implausibly large (>= 80 years) symptom-to-exposure times are also set to missing.

In [9]:
df.loc[df.symptom_to_exposure<0, 'symptom_to_exposure'] = 0
df.loc[df.age_first_diagnosis<0, 'age_first_diagnosis'] = np.nan
df.loc[df.symptom_to_exposure>=80, 'symptom_to_exposure'] = np.nan

Target Selection

In [10]:
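# Three classes from response2_1: 0 = non-responder, 2 = responder (1 recoded to 2), 1 = missing response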
target_multiclass = df.response2_1.replace(1, 2).fillna(1)
df['target_multiclass'] = target_multiclass

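# Binary target: 1 = responder among lithium-exposed patients, 0 = everyone else (incl. all olanzapine patients)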
target_lithium2y = (df.exposure=='lithium') & (df.response2_1==1)
df['target_lithium2y'] = target_lithium2y.astype(int)

target_exposure = (df.exposure=='olanzapine')
df['target_exposure'] = target_exposure.astype(int)
targets = {
    'resp': 'response2_1',
    'exp': 'target_exposure',  # 0=lithium 1=olanzapine
    'multi': 'target_multiclass',
    'lithium2y': 'target_lithium2y'
}

for target in targets:
    print(df[targets[target]].value_counts())
    print()
0.0    14785
1.0    11848
Name: response2_1, dtype: int64

0    19106
1    12412
Name: target_exposure, dtype: int64

0.0    14785
2.0    11848
1.0     4885
Name: target_multiclass, dtype: int64

0    23501
1     8017
Name: target_lithium2y, dtype: int64

Feature Selection

We have a lot of "age" features. Those with few missing values are good candidates: let's keep only the age features with fewer than 5000 missing values.

In [11]:
age_columns = [col for col in df.columns if col.startswith('age_')]
with pd.option_context('display.max_rows', None, 'display.max_columns', None):
    age_df = df[age_columns].isna().sum().sort_values(ascending=True)
    display(age_df)
features_age = age_df[age_df < 5000].index.to_list()
age_first_exposure                    0
age_first_reg                         0
age_diagnosis                        21
age_first_diagnosis                  26
age_first                           143
age_smoke                           968
age_BP                             3102
age_eGFR                           5370
age_TSH                            9770
age_BMI                           10303
age_first_li                      10466
age_first_AP                      11000
age_depression                    11533
age_ethnicity                     12565
age_last_SSRI                     12704
age_first_SSRI                    12722
age_HDL                           14024
age_first_olan                    14456
age_transfer_out                  15942
age_LDL                           16581
age_ca                            16668
age_anxiety                       17770
age_first_MS                      17786
age_last_TCAs                     18235
age_first_TCAs                    18258
age_psychosis                     19271
age_last_other_ADs                19447
age_first_other_ADs               19451
age_sleep                         21187
age_mania                         21830
age_Chronic_pulmonary_disease     23645
age_self_harm                     24530
age_dermatitis                    24955
age_Diabetes_uncomplicated        25385
age_stress                        25915
age_asthma                        25949
age_Neurological_disorders        26065
age_T2DM                          26324
age_alcohol                       26479
age_death                         26786
age_hypothyroid                   27231
age_PD                            28117
age_CHD                           28361
age_relationship                  28484
age_migraine                      28782
age_other_substance_misuse        28977
age_Deficiency_anaemia            29025
age_cardiac_arrythmia             29594
age_Weight_loss                   29594
age_Fluid_electrolyte_disorder    29903
age_Diabetes_organ_damage         30053
age_RA                            30419
age_Congestive_heart_failure      30524
age_OCD                           30601
age_Peptic_ulcer                  30629
age_Peripheral_vascular           30642
age_cannabis                      30764
age_Liver_disease                 30781
age_Valvular_disease              30950
age_psych_FH                      31014
age_FH_NOS                        31052
age_Pulmonary_circulation         31056
age_hyperthyroid                  31077
age_FH_depression                 31108
age_Coagulopathy                  31208
age_adhd                          31217
age_FH_BPD                        31234
age_FH_psychosis                  31267
age_conduct                       31275
age_HIV_AIDS                      31289
age_FH_suicide                    31383
age_FH_substance                  31398
age_FH_anxiety                    31492
age_FH_LD                         31509
dtype: int64

We now remove features that do not make sense from a clinical perspective and could confuse the learning process.

In [12]:
features_age.remove('age_BP')
features_age.remove('age_first')
features_age.remove('age_diagnosis')
features_age.remove('age_smoke')
print('age features:', features_age)
age features: ['age_first_exposure', 'age_first_reg', 'age_first_diagnosis']

We now select the feature lists, both agnostic and clinically informed.

In [13]:
# list of agnostic features (to complement with age features)
features_agnostic = """adhd
FH_suicide
mania
psychosis
relationship
self_harm
sex
sleep
smoker
T2DM
OCD
migraine
hypothyroid
CHD
other_substance_misuse
cannabis
alcohol
depression
FH_anxiety
FH_any
FH_BPD
FH_depression
FH_LD
FH_psychosis
N_dep_b4
N_man_b4
anxiety
stress
hi_LDL
lo_HDL
weight
CKD3
symptom_to_exposure
dominant"""



# list of informed features:
features_informed = """age_first_exposure
age_first_diagnosis
symptom_to_exposure
psychosis
depression
mania
dominant
sex
FH_BPD
FH_depression
FH_psychosis
weight
self_harm
cannabis
anxiety
stress
sleep
other_substance_misuse
relationship
OCD
adhd
smoker
alcohol
FH_suicide
hi_LDL
lo_HDL
CKD3
T2DM
migraine
hypothyroid
CHD
FH_anxiety
FH_any
FH_LD"""



shaky_features="""cannabis
anxiety
stress
sleep
other_substance_misuse
relationship
OCD
adhd
smoker
alcohol
FH_suicide
hi_LDL
lo_HDL
CKD3
T2DM
migraine
hypothyroid
CHD
FH_anxiety
FH_any
FH_LD"""
shaky_features = [ feature.strip() for feature in shaky_features.split('\n') ]
In [14]:
features = {
    # informed
    '34': [ feature.strip() for feature in features_informed.split('\n') ],
    # agnostic
    '37' : [ feature.strip() for feature in features_agnostic.split("\n") ] + features_age
}
features.update({
    '13' : [feature for feature in features['34'] if feature not in shaky_features]
})

print(len(features['34']), 'informed features:', features['34'])
print(len(features['37']), 'agnostic features:', features['37'])
print(len(features['13']), 'important features:', features['13'])
print(len(shaky_features), 'shaky features:', shaky_features)
34 informed features: ['age_first_exposure', 'age_first_diagnosis', 'symptom_to_exposure', 'psychosis', 'depression', 'mania', 'dominant', 'sex', 'FH_BPD', 'FH_depression', 'FH_psychosis', 'weight', 'self_harm', 'cannabis', 'anxiety', 'stress', 'sleep', 'other_substance_misuse', 'relationship', 'OCD', 'adhd', 'smoker', 'alcohol', 'FH_suicide', 'hi_LDL', 'lo_HDL', 'CKD3', 'T2DM', 'migraine', 'hypothyroid', 'CHD', 'FH_anxiety', 'FH_any', 'FH_LD']
37 agnostic features: ['adhd', 'FH_suicide', 'mania', 'psychosis', 'relationship', 'self_harm', 'sex', 'sleep', 'smoker', 'T2DM', 'OCD', 'migraine', 'hypothyroid', 'CHD', 'other_substance_misuse', 'cannabis', 'alcohol', 'depression', 'FH_anxiety', 'FH_any', 'FH_BPD', 'FH_depression', 'FH_LD', 'FH_psychosis', 'N_dep_b4', 'N_man_b4', 'anxiety', 'stress', 'hi_LDL', 'lo_HDL', 'weight', 'CKD3', 'symptom_to_exposure', 'dominant', 'age_first_exposure', 'age_first_reg', 'age_first_diagnosis']
13 important features: ['age_first_exposure', 'age_first_diagnosis', 'symptom_to_exposure', 'psychosis', 'depression', 'mania', 'dominant', 'sex', 'FH_BPD', 'FH_depression', 'FH_psychosis', 'weight', 'self_harm']
21 shaky features: ['cannabis', 'anxiety', 'stress', 'sleep', 'other_substance_misuse', 'relationship', 'OCD', 'adhd', 'smoker', 'alcohol', 'FH_suicide', 'hi_LDL', 'lo_HDL', 'CKD3', 'T2DM', 'migraine', 'hypothyroid', 'CHD', 'FH_anxiety', 'FH_any', 'FH_LD']
In [15]:
print("Agnostic features that are not informed features:", list(set(features['37']) - set(features['34'])))
print("Informed features that are not agnostic features:", list(set(features['34']) - set(features['37'])))
Agnostic features that are not informed features: ['N_man_b4', 'N_dep_b4', 'age_first_reg']
Informed features that are not agnostic features: []

Preparation of the dataframes (samples=X and target=y)

Here we prepare the X dataframe (samples with their features) and the y series (the target to predict for each sample). We also fix the types of a few features to make sure the algorithms will interpret them correctly.

In [16]:
from sklearn import preprocessing, metrics

def prepare(features, target):

    # We first remove the samples with N/A for any of the features or label
    df_withoutNA = df.dropna(subset=(features + [target]))
    df_features = df_withoutNA[features]

    # Encode all string feature values as integers (one-hot encoding could be an alternative later)
    le = preprocessing.LabelEncoder()
    to_encode = [key for key in dict(df_features.dtypes) if dict(df.dtypes)[key] not in ['float64', 'int64']]
    new_df_features = df_features.copy()
    new_df_features.update(df_features[to_encode].apply(le.fit_transform))
    new_df_features[to_encode] = new_df_features[to_encode].astype(np.int64)

    # Because lo_HDL has missing values, pandas stored it as float.
    # Now that rows with missing values are dropped, we can cast it back to int
    if 'lo_HDL' in features:
        new_df_features.lo_HDL = new_df_features.lo_HDL.astype(np.int64)

    # Now we have our X and y
    return new_df_features, df_withoutNA[target]

X = dict()
y = dict()
for feature_set in features:
    for target in targets:
        key = feature_set + '_' + target
        X[key], y[key] = prepare(features[feature_set] + ['exposure'], targets[target])
        if target != 'exp':
            y[key] = y[key].astype(np.int64)
        print(key, len(X[key]), 'rows')
34_resp 26509 rows
34_exp 31369 rows
34_multi 31369 rows
34_lithium2y 31369 rows
37_resp 26509 rows
37_exp 31369 rows
37_multi 31369 rows
37_lithium2y 31369 rows
13_resp 26510 rows
13_exp 31370 rows
13_multi 31370 rows
13_lithium2y 31370 rows
In [17]:
from scipy.stats import pearsonr, spearmanr
import seaborn as sns

# The difference between the X and X_dict dictionaries is that we drop 'exposure' in X_dict:
# we don't want it as a feature there, but we keep it in X to know each patient's exposure.
X_dict = dict()
y_dict = dict()
for key in X:
    X_dict.update({
        key: X[key].drop('exposure', axis=1)
    })
    X_dict.update({
        'num_' + key: X_dict[key].loc[:, X_dict[key].dtypes == np.float64],
        'cat_' + key: X_dict[key].loc[:, X_dict[key].dtypes == np.int64],
    })
    X_dict.update({
        'bin_' + key: X_dict['cat_' + key].loc[:, X_dict['cat_' + key].nunique() == 2]
    })

    y_dict.update({
        key: y[key],
        'num_' + key: y[key],
        'cat_' + key: y[key],
        'bin_' + key: y[key],
    })

    # When predicting exposure or lithium2y, it makes no sense to split the data by exposure group
    if '_exp' not in key and '_lithium2y' not in key:
        X_dict.update({
            'lit_' + key: X[key].loc[X[key].exposure == 0].drop('exposure', axis=1),
            'ola_' + key: X[key].loc[X[key].exposure == 1].drop('exposure', axis=1),
        })
        y_dict.update({
            'lit_' + key : y[key][X[key].loc[X[key].exposure == 0].index],
            'ola_' + key : y[key][X[key].loc[X[key].exposure == 1].index],
        })


print(X_dict.keys())
print(y_dict.keys())
dict_keys(['34_resp', 'num_34_resp', 'cat_34_resp', 'bin_34_resp', 'lit_34_resp', 'ola_34_resp', '34_exp', 'num_34_exp', 'cat_34_exp', 'bin_34_exp', '34_multi', 'num_34_multi', 'cat_34_multi', 'bin_34_multi', 'lit_34_multi', 'ola_34_multi', '34_lithium2y', 'num_34_lithium2y', 'cat_34_lithium2y', 'bin_34_lithium2y', '37_resp', 'num_37_resp', 'cat_37_resp', 'bin_37_resp', 'lit_37_resp', 'ola_37_resp', '37_exp', 'num_37_exp', 'cat_37_exp', 'bin_37_exp', '37_multi', 'num_37_multi', 'cat_37_multi', 'bin_37_multi', 'lit_37_multi', 'ola_37_multi', '37_lithium2y', 'num_37_lithium2y', 'cat_37_lithium2y', 'bin_37_lithium2y', '13_resp', 'num_13_resp', 'cat_13_resp', 'bin_13_resp', 'lit_13_resp', 'ola_13_resp', '13_exp', 'num_13_exp', 'cat_13_exp', 'bin_13_exp', '13_multi', 'num_13_multi', 'cat_13_multi', 'bin_13_multi', 'lit_13_multi', 'ola_13_multi', '13_lithium2y', 'num_13_lithium2y', 'cat_13_lithium2y', 'bin_13_lithium2y'])
dict_keys(['34_resp', 'num_34_resp', 'cat_34_resp', 'bin_34_resp', 'lit_34_resp', 'ola_34_resp', '34_exp', 'num_34_exp', 'cat_34_exp', 'bin_34_exp', '34_multi', 'num_34_multi', 'cat_34_multi', 'bin_34_multi', 'lit_34_multi', 'ola_34_multi', '34_lithium2y', 'num_34_lithium2y', 'cat_34_lithium2y', 'bin_34_lithium2y', '37_resp', 'num_37_resp', 'cat_37_resp', 'bin_37_resp', 'lit_37_resp', 'ola_37_resp', '37_exp', 'num_37_exp', 'cat_37_exp', 'bin_37_exp', '37_multi', 'num_37_multi', 'cat_37_multi', 'bin_37_multi', 'lit_37_multi', 'ola_37_multi', '37_lithium2y', 'num_37_lithium2y', 'cat_37_lithium2y', 'bin_37_lithium2y', '13_resp', 'num_13_resp', 'cat_13_resp', 'bin_13_resp', 'lit_13_resp', 'ola_13_resp', '13_exp', 'num_13_exp', 'cat_13_exp', 'bin_13_exp', '13_multi', 'num_13_multi', 'cat_13_multi', 'bin_13_multi', 'lit_13_multi', 'ola_13_multi', '13_lithium2y', 'num_13_lithium2y', 'cat_13_lithium2y', 'bin_13_lithium2y'])
In [18]:
display(X_dict['cat_13_lithium2y'])
psychosis depression mania dominant sex FH_BPD FH_depression FH_psychosis weight self_harm
1 1 0 0 2 0 0 0 0 0 0
2 0 1 0 0 0 0 0 0 0 1
3 0 0 1 1 0 0 0 0 2 0
4 0 1 1 0 0 0 0 0 0 1
5 1 1 0 0 0 0 0 0 0 0
... ... ... ... ... ... ... ... ... ... ...
38950 0 0 0 2 0 0 0 0 0 1
38951 1 1 0 0 0 0 0 0 0 0
38953 0 1 0 0 0 0 0 0 2 0
38954 0 0 1 1 0 0 0 0 0 0
38956 1 1 0 0 0 0 0 0 0 0

31370 rows × 10 columns
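
Note that prepare() refits one LabelEncoder per column and does not keep it, so the integer-to-category mapping is not stored anywhere. A minimal sketch to inspect the mapping for a single column (here dominant, purely for illustration):

le = preprocessing.LabelEncoder()
le.fit(df['dominant'].dropna())
# LabelEncoder assigns codes in sorted order of the classes
print(dict(zip(le.classes_, le.transform(le.classes_))))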

Feature Characterisation

Distribution and plots

In [19]:
%matplotlib inline
import seaborn as sns
sns.set_theme(style="white")
sns.set(font_scale=2)
from matplotlib.ticker import FormatStrFormatter

values = {0: df.age_first_exposure,
         1: df.age_first_diagnosis}

fig, ax = plt.subplots(ncols=2, nrows=1, sharey=True)
plt.subplots_adjust(hspace=0.6)
fig.set_size_inches(10, 6)
fig.tight_layout()
for i in values.keys():
    sns.distplot(values[i], kde=False, ax=ax[i], bins=range(100))
    # ax[i].set_yscale('log')
    ax[i].yaxis.set_major_formatter(FormatStrFormatter('%.0f'))
    ax[i].set_xlim((0,100))
display(df.age_first_diagnosis)
/Users/fehmi/GoogleDrive/sics/projects/ucl/lithium/venv/lib/python3.9/site-packages/seaborn/distributions.py:2619: FutureWarning: `distplot` is a deprecated function and will be removed in a future version. Please adapt your code to use either `displot` (a figure-level function with similar flexibility) or `histplot` (an axes-level function for histograms).
  warnings.warn(msg, FutureWarning)
1        80.722794
2        41.752224
3        56.533882
4        21.207392
5        82.948662
           ...    
38950    55.764545
38951    46.135525
38953    28.867899
38954    31.498974
38956    38.529774
Name: age_first_diagnosis, Length: 31518, dtype: float64
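
As the FutureWarning above notes, distplot is deprecated. A sketch of the same figure with the non-deprecated axes-level histplot, reusing the values dictionary defined above:

fig, ax = plt.subplots(ncols=2, nrows=1, sharey=True, figsize=(10, 6))
for i, vals in values.items():
    sns.histplot(vals, ax=ax[i], bins=range(100))  # counts, like distplot(kde=False)
    ax[i].set_xlim((0, 100))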

Now let's check the number of missing values for each feature, and look at their entropy

In [20]:
import scipy.stats as st

def entropy3(labels, base=None):
    # value_counts gives raw counts; scipy's entropy normalizes them to
    # probabilities and uses the natural logarithm when base is None
    vc = pd.Series(labels).value_counts(normalize=False, sort=False)
    return st.entropy(vc, base=base)

nan_entropy = pd.concat([df[features['37']].isna().sum(), df[features['37']].apply(entropy3, axis=0)], axis=1)
nan_entropy.columns = ['N/A', 'Entropy']
display(nan_entropy)
display(df[features['37']])
N/A Entropy
adhd 0 0.033217
FH_suicide 0 0.027631
mania 0 0.573333
psychosis 0 0.590539
relationship 0 0.390328
self_harm 0 0.391580
sex 0 0.679209
sleep 0 0.379125
smoker 0 1.066039
T2DM 0 0.175462
OCD 0 0.096136
migraine 0 0.203815
hypothyroid 0 0.201039
CHD 0 0.145867
other_substance_misuse 0 0.191553
cannabis 0 0.078576
alcohol 0 0.224948
depression 0 0.675798
FH_anxiety 0 0.010781
FH_any 0 0.233201
FH_BPD 0 0.086404
FH_depression 0 0.096757
FH_LD 0 0.002616
FH_psychosis 0 0.046572
N_dep_b4 0 1.995139
N_man_b4 0 0.893913
anxiety 0 0.547383
stress 0 0.302052
hi_LDL 0 0.228529
lo_HDL 1 0.119001
weight 0 1.014700
CKD3 0 0.094266
symptom_to_exposure 143 8.074818
dominant 0 0.988668
age_first_exposure 0 9.559759
age_first_reg 0 9.567249
age_first_diagnosis 26 8.743020
adhd FH_suicide mania psychosis relationship self_harm sex sleep smoker T2DM OCD migraine hypothyroid CHD other_substance_misuse cannabis alcohol depression FH_anxiety FH_any FH_BPD FH_depression FH_LD FH_psychosis N_dep_b4 N_man_b4 anxiety stress hi_LDL lo_HDL weight CKD3 symptom_to_exposure dominant age_first_exposure age_first_reg age_first_diagnosis
1 0 0 0 1 0 0 female 0 ex smoker 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0.0 healthy weight 0 0.000000 unclear 69.434631 64.391781 80.722794
2 0 0 0 0 0 1 female 1 ex smoker 0 0 0 0 0 0 0 0 1 0 0 0 0 0 0 1 0 1 0 1 0.0 healthy weight 0 24.128679 depression 43.627651 27.057534 41.752224
3 0 0 1 0 0 0 female 0 current smoker 0 0 0 0 1 0 0 0 0 0 0 0 0 0 0 0 1 0 0 0 0.0 overweight 0 0.000000 mania 56.533882 50.843836 56.533882
4 0 0 1 0 0 1 female 1 ex smoker 0 0 0 0 0 0 0 0 1 0 0 0 0 0 0 12 1 0 0 0 0.0 healthy weight 0 10.973306 depression 26.524298 22.583562 21.207392
5 0 0 0 1 0 0 female 0 current smoker 0 0 0 0 0 0 0 0 1 0 0 0 0 0 0 2 0 0 0 0 0.0 healthy weight 0 13.349760 depression 72.032852 67.394521 82.948662
... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ...
38950 0 0 0 0 1 1 female 0 current smoker 0 0 0 0 0 0 0 1 0 0 0 0 0 0 0 0 0 0 0 0 0.0 healthy weight 1 18.521561 unclear 74.286102 45.613699 55.764545
38951 0 0 0 1 0 0 female 0 never smoker 0 0 0 0 0 0 0 0 1 0 0 0 0 0 0 1 0 1 0 0 0.0 healthy weight 0 7.293634 depression 43.791924 44.260274 46.135525
38953 0 0 0 0 1 0 female 1 current smoker 0 0 0 0 0 1 1 0 1 0 1 0 0 0 0 13 0 1 1 0 0.0 overweight 0 11.707050 depression 37.223820 45.471233 28.867899
38954 0 0 1 0 0 0 female 0 ex smoker 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 1 0 0 0.0 healthy weight 0 16.922655 mania 48.421631 39.983562 31.498974
38956 0 0 0 1 0 0 female 0 current smoker 0 0 0 1 0 0 0 0 1 0 0 0 0 0 0 2 0 0 0 0 0.0 healthy weight 0 14.631075 depression 39.129364 30.767123 38.529774

31518 rows × 37 columns

In [21]:
key = '34_exp'
import itertools
variables_pearson = X_dict['num_' + key].columns
for variables in itertools.combinations(variables_pearson, 2):
    plt.figure(figsize=(16,9))
    sns.regplot(x = variables[0], y = variables[1], data = X_dict['num_'+key],
            line_kws={"color": "red"})
In [22]:
display(X_dict['num_'+ key])
age_first_exposure age_first_diagnosis symptom_to_exposure
1 69.434631 80.722794 0.000000
2 43.627651 41.752224 24.128679
3 56.533882 56.533882 0.000000
4 26.524298 21.207392 10.973306
5 72.032852 82.948662 13.349760
... ... ... ...
38950 74.286102 55.764545 18.521561
38951 43.791924 46.135525 7.293634
38953 37.223820 28.867899 11.707050
38954 48.421631 31.498974 16.922655
38956 39.129364 38.529774 14.631075

31369 rows × 3 columns

Correlation coefficients

With target

In [23]:
from pyitlib import discrete_random_variable as drv
from scipy.spatial import distance

def similarity(X, Y):
    # Simple Matching Coefficient (SMC): fraction of positions where the two
    # binary vectors agree (both 1 or both 0)
    sim = ((X * Y) + ((1 - X) * (1 - Y))) / len(X)
    return sim.sum()

def coeff(X_dict, y, key, feature_set):
    print(len(features[feature_set]), 'features')
    df_coeff = pd.DataFrame(index=features[feature_set])

    for feat in features[feature_set]:
        # Correlation coefficients only make sense for numerical features
        if feat in X_dict['num_' + key].columns:
            df_coeff.loc[feat, 'pearson_r'], df_coeff.loc[feat, 'pearson_p'] = pearsonr(X[key][feat], y[key])
            df_coeff.loc[feat, 'spearman_r'], df_coeff.loc[feat, 'spearman_p'] = spearmanr(X[key][feat], y[key])
        # Conditional entropy for categorical features (including non-binary ones)
        if feat in X_dict['cat_' + key].columns:
            df_coeff.loc[feat, 'cond_entropy'] = drv.entropy_conditional([int(x) for x in X[key][feat]], [int(x) for x in y[key]], base=np.e)
        # For binary features the Jaccard similarity and the SMC are more appropriate
        if (feat in X_dict['bin_' + key].columns) and (y[key].dtype != np.float64):
            df_coeff.loc[feat, 'jaccard'] = 1 - distance.jaccard(X[key][feat], y[key])
            df_coeff.loc[feat, 'SMC'] = similarity(X[key][feat], y[key])

    # Display the rows sorted by absolute Pearson r, descending
    display(df_coeff
            .reindex(df_coeff.pearson_r.abs().sort_values(ascending=False).index))
    return df_coeff

For binary features (see the toy example after this list):

  • The Jaccard coefficient only counts matches on "1" and ignores 0-0 agreements (e.g., widely used to compare shopping baskets between customers).
    • 0 = feature and target are never both "1" for the same patient
    • 1 = feature and target are always both "1" together
  • The SMC (Simple Matching Coefficient) counts matches on both "1" and "0".
    • 0 = feature and target never take the same value
    • 1 = feature and target always take the same value
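
A toy illustration of the difference (not from the original analysis; it reuses the distance module imported above):

a = np.array([1, 1, 0, 0, 1])
b = np.array([1, 0, 0, 0, 1])
print(1 - distance.jaccard(a, b))  # 0.667: 2 shared "1"s over 3 positions where either vector is 1
print((a == b).mean())             # 0.800: 4 agreements (on "1"s and "0"s) over 5 positions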
In [24]:
feature_set = '34'
for target in targets:
    key = feature_set + '_' + target
    print(10*'_' + target + 10*'_')
    df_coeff = coeff(X_dict, y, key, feature_set)
    print()
__________resp__________
34 features
pearson_r pearson_p spearman_r spearman_p cond_entropy jaccard SMC
age_first_exposure 0.193293 1.693068e-221 0.195700 4.192704e-227 NaN NaN NaN
age_first_diagnosis 0.131446 1.780416e-102 0.130364 8.174834e-101 NaN NaN NaN
symptom_to_exposure 0.122443 4.498120e-89 0.119015 3.168104e-84 NaN NaN NaN
psychosis NaN NaN NaN NaN 0.589143 0.209603 0.529292
depression NaN NaN NaN NaN 0.676162 0.345302 0.495832
mania NaN NaN NaN NaN 0.574144 0.197438 0.527255
dominant NaN NaN NaN NaN 0.989417 NaN NaN
sex NaN NaN NaN NaN 0.679791 0.274790 0.509186
FH_BPD NaN NaN NaN NaN 0.086441 0.015442 0.552642
FH_depression NaN NaN NaN NaN 0.095472 0.017635 0.552416
FH_psychosis NaN NaN NaN NaN 0.044860 0.006384 0.553774
weight NaN NaN NaN NaN 1.014475 NaN NaN
self_harm NaN NaN NaN NaN 0.391276 0.108984 0.536459
cannabis NaN NaN NaN NaN 0.076336 0.011893 0.551813
anxiety NaN NaN NaN NaN 0.548561 0.172467 0.518352
stress NaN NaN NaN NaN 0.300254 0.064856 0.531140
sleep NaN NaN NaN NaN 0.379208 0.104652 0.537516
other_substance_misuse NaN NaN NaN NaN 0.192186 0.044094 0.549398
relationship NaN NaN NaN NaN 0.392561 0.113042 0.539741
OCD NaN NaN NaN NaN 0.098012 0.019634 0.553586
adhd NaN NaN NaN NaN 0.033642 0.004212 0.554038
smoker NaN NaN NaN NaN 1.065933 NaN NaN
alcohol NaN NaN NaN NaN 0.227904 0.049326 0.542684
FH_suicide NaN NaN NaN NaN 0.026897 0.004309 0.555396
hi_LDL NaN NaN NaN NaN 0.229586 0.055144 0.547550
lo_HDL NaN NaN NaN NaN 0.119419 0.025348 0.553246
CKD3 NaN NaN NaN NaN 0.092342 0.019442 0.554793
T2DM NaN NaN NaN NaN 0.175593 0.049850 0.559244
migraine NaN NaN NaN NaN 0.203305 0.044653 0.546418
hypothyroid NaN NaN NaN NaN 0.201963 0.057541 0.558226
CHD NaN NaN NaN NaN 0.144386 0.036291 0.556226
FH_anxiety NaN NaN NaN NaN 0.009570 0.001187 0.555509
FH_any NaN NaN NaN NaN 0.231687 0.052400 0.544306
FH_LD NaN NaN NaN NaN 0.002625 0.000594 0.555924
__________exp__________
34 features
pearson_r pearson_p spearman_r spearman_p cond_entropy jaccard SMC
age_first_exposure -0.106543 7.323420e-80 -0.108190 2.643016e-82 NaN NaN NaN
symptom_to_exposure -0.075266 1.201673e-40 -0.082721 9.258229e-49 NaN NaN NaN
age_first_diagnosis -0.059584 4.470683e-26 -0.054850 2.440734e-22 NaN NaN NaN
psychosis NaN NaN NaN NaN 0.589280 0.219806 0.570021
depression NaN NaN NaN NaN 0.676135 0.315729 0.487009
mania NaN NaN NaN NaN 0.572370 0.204563 0.567630
dominant NaN NaN NaN NaN 0.988610 NaN NaN
sex NaN NaN NaN NaN 0.678129 0.278029 0.541394
FH_BPD NaN NaN NaN NaN 0.085809 0.010562 0.596831
FH_depression NaN NaN NaN NaN 0.096748 0.019133 0.601231
FH_psychosis NaN NaN NaN NaN 0.046308 0.011288 0.606299
weight NaN NaN NaN NaN 1.010127 NaN NaN
self_harm NaN NaN NaN NaN 0.391315 0.123489 0.588415
cannabis NaN NaN NaN NaN 0.075429 0.027083 0.611782
anxiety NaN NaN NaN NaN 0.541934 0.224798 0.600179
stress NaN NaN NaN NaN 0.293544 0.126974 0.624534
sleep NaN NaN NaN NaN 0.373184 0.150662 0.615416
other_substance_misuse NaN NaN NaN NaN 0.185501 0.074975 0.619274
relationship NaN NaN NaN NaN 0.388881 0.092250 0.562402
OCD NaN NaN NaN NaN 0.095491 0.025091 0.606108
adhd NaN NaN NaN NaN 0.033051 0.007548 0.605980
smoker NaN NaN NaN NaN 1.060558 NaN NaN
alcohol NaN NaN NaN NaN 0.222606 0.075436 0.609678
FH_suicide NaN NaN NaN NaN 0.027568 0.004173 0.604386
hi_LDL NaN NaN NaN NaN 0.223190 0.087069 0.617616
lo_HDL NaN NaN NaN NaN 0.117458 0.036801 0.609519
CKD3 NaN NaN NaN NaN 0.093997 0.023179 0.605024
T2DM NaN NaN NaN NaN 0.175785 0.039021 0.595684
migraine NaN NaN NaN NaN 0.203515 0.056255 0.601039
hypothyroid NaN NaN NaN NaN 0.200442 0.037254 0.586439
CHD NaN NaN NaN NaN 0.145779 0.035962 0.601772
FH_anxiety NaN NaN NaN NaN 0.010400 0.001532 0.605183
FH_any NaN NaN NaN NaN 0.233395 0.057539 0.592719
FH_LD NaN NaN NaN NaN 0.002539 0.000565 0.605502
__________multi__________
34 features
pearson_r pearson_p spearman_r spearman_p cond_entropy jaccard SMC
age_first_exposure 0.176324 1.914504e-217 0.177382 4.462611e-220 NaN NaN NaN
age_first_diagnosis 0.120163 3.210233e-101 0.118399 2.643382e-98 NaN NaN NaN
symptom_to_exposure 0.112088 3.062966e-88 0.107891 7.372539e-82 NaN NaN NaN
psychosis NaN NaN NaN NaN 0.590599 0.067322 0.327106
depression NaN NaN NaN NaN 0.676158 0.113832 0.584686
mania NaN NaN NaN NaN 0.573145 0.060513 0.306162
dominant NaN NaN NaN NaN 0.989188 NaN NaN
sex NaN NaN NaN NaN 0.679341 0.087408 0.432688
FH_BPD NaN NaN NaN NaN 0.086569 0.004969 0.106092
FH_depression NaN NaN NaN NaN 0.096706 0.006140 0.108260
FH_psychosis NaN NaN NaN NaN 0.046645 0.002923 0.098919
weight NaN NaN NaN NaN 1.014602 NaN NaN
self_harm NaN NaN NaN NaN 0.391790 0.035012 0.194523
cannabis NaN NaN NaN NaN 0.078415 0.005153 0.102745
anxiety NaN NaN NaN NaN 0.547361 0.055399 0.268067
stress NaN NaN NaN NaN 0.301458 0.024737 0.142657
sleep NaN NaN NaN NaN 0.379355 0.033202 0.189773
other_substance_misuse NaN NaN NaN NaN 0.191935 0.013252 0.131276
relationship NaN NaN NaN NaN 0.389897 0.032445 0.199050
OCD NaN NaN NaN NaN 0.096090 0.004843 0.110077
adhd NaN NaN NaN NaN 0.033256 0.001434 0.096688
smoker NaN NaN NaN NaN 1.065385 NaN NaN
alcohol NaN NaN NaN NaN 0.224836 0.014706 0.131499
FH_suicide NaN NaN NaN NaN 0.027556 0.001437 0.097899
hi_LDL NaN NaN NaN NaN 0.228775 0.016350 0.141031
lo_HDL NaN NaN NaN NaN 0.119321 0.007289 0.115656
CKD3 NaN NaN NaN NaN 0.094304 0.006274 0.111671
T2DM NaN NaN NaN NaN 0.175073 0.011715 0.142625
migraine NaN NaN NaN NaN 0.203864 0.014731 0.130320
hypothyroid NaN NaN NaN NaN 0.200578 0.013537 0.149319
CHD NaN NaN NaN NaN 0.145722 0.010078 0.128311
FH_anxiety NaN NaN NaN NaN 0.010378 0.000600 0.095190
FH_any NaN NaN NaN NaN 0.233232 0.018060 0.137237
FH_LD NaN NaN NaN NaN 0.002521 0.000060 0.094807
__________lithium2y__________
34 features
pearson_r pearson_p spearman_r spearman_p cond_entropy jaccard SMC
age_first_exposure 0.154052 7.457636e-166 0.157628 1.206823e-173 NaN NaN NaN
age_first_diagnosis 0.104651 4.215269e-77 0.101220 3.181159e-72 NaN NaN NaN
symptom_to_exposure 0.103478 2.051218e-75 0.101697 6.819342e-73 NaN NaN NaN
psychosis NaN NaN NaN NaN 0.590447 0.142651 0.601294
depression NaN NaN NaN NaN 0.676061 0.222709 0.462495
mania NaN NaN NaN NaN 0.572964 0.137856 0.610635
dominant NaN NaN NaN NaN 0.988927 NaN NaN
sex NaN NaN NaN NaN 0.679087 0.176928 0.530779
FH_BPD NaN NaN NaN NaN 0.086570 0.018090 0.738723
FH_depression NaN NaN NaN NaN 0.096748 0.018281 0.736364
FH_psychosis NaN NaN NaN NaN 0.046481 0.004525 0.740540
weight NaN NaN NaN NaN 1.014417 NaN NaN
self_harm NaN NaN NaN NaN 0.391424 0.082263 0.672097
cannabis NaN NaN NaN NaN 0.077634 0.006322 0.734419
anxiety NaN NaN NaN NaN 0.546391 0.115886 0.611113
stress NaN NaN NaN NaN 0.299668 0.041727 0.683732
sleep NaN NaN NaN NaN 0.378212 0.070755 0.670088
other_substance_misuse NaN NaN NaN NaN 0.190062 0.022924 0.711945
relationship NaN NaN NaN NaN 0.389486 0.109298 0.690331
OCD NaN NaN NaN NaN 0.095951 0.014316 0.734419
adhd NaN NaN NaN NaN 0.033194 0.003209 0.742516
smoker NaN NaN NaN NaN 1.062893 NaN NaN
alcohol NaN NaN NaN NaN 0.223868 0.032913 0.706812
FH_suicide NaN NaN NaN NaN 0.027538 0.005091 0.744557
hi_LDL NaN NaN NaN NaN 0.227714 0.035373 0.707036
lo_HDL NaN NaN NaN NaN 0.119160 0.019302 0.731136
CKD3 NaN NaN NaN NaN 0.094344 0.016388 0.735950
T2DM NaN NaN NaN NaN 0.175457 0.046268 0.729924
migraine NaN NaN NaN NaN 0.203862 0.039913 0.717811
hypothyroid NaN NaN NaN NaN 0.200551 0.056950 0.728139
CHD NaN NaN NaN NaN 0.145969 0.029967 0.729638
FH_anxiety NaN NaN NaN NaN 0.010389 0.001001 0.745354
FH_any NaN NaN NaN NaN 0.233339 0.048752 0.713252
FH_LD NaN NaN NaN NaN 0.002609 0.000125 0.745991

Pairwise

In [25]:
import pandas as pd
import numpy as np
from scipy import stats
import seaborn as sns
import matplotlib.pyplot as plt

def chisquare(Y):
    # Pairwise chi-square tests of independence between all integer-coded
    # (categorical) columns; returns a square matrix of p-values.
    X = Y.loc[:, Y.dtypes == np.int64]
    column_names = X.columns
    chisqmatrix = pd.DataFrame(index=column_names, columns=column_names, dtype=float)

    for i, icol in enumerate(column_names):
        for j, jcol in enumerate(column_names):
            mycrosstab = pd.crosstab(X[icol], X[jcol])
            stat, p, dof, expected = stats.chi2_contingency(mycrosstab)
            chisqmatrix.iloc[i, j] = round(p, 3)

            # Flag cells where the chi-square approximation is unreliable,
            # i.e. fewer than 20% of expected counts are >= 5
            cntexpected = expected[expected < 5].size
            perexpected = ((expected.size - cntexpected) / expected.size) * 100
            if perexpected < 20:
                chisqmatrix.iloc[i, j] = 2
            if icol == jcol:
                chisqmatrix.iloc[i, j] = 0.00

    return chisqmatrix
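
A quick illustration on a hypothetical toy frame (not the study cohort):

toy = pd.DataFrame({'a': [0, 1, 1, 0] * 25,
                    'b': [0, 1, 0, 1] * 25,
                    'c': [1, 1, 0, 0] * 25})
display(chisquare(toy))  # symmetric matrix: 0.0 on the diagonal, p-values elsewhere
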
In [26]:
df_corr_pearson = X_dict['num_'+ key].corr(method='pearson', min_periods=1)
df_corr_spearman = X_dict['num_' + key].corr(method='spearman', min_periods=1)
df_corr_chisquare = chisquare(X_dict['cat_' + key])
In [27]:
print(X_dict[key].dtypes)
age_first_exposure        float64
age_first_diagnosis       float64
symptom_to_exposure       float64
psychosis                   int64
depression                  int64
mania                       int64
dominant                    int64
sex                         int64
FH_BPD                      int64
FH_depression               int64
FH_psychosis                int64
weight                      int64
self_harm                   int64
cannabis                    int64
anxiety                     int64
stress                      int64
sleep                       int64
other_substance_misuse      int64
relationship                int64
OCD                         int64
adhd                        int64
smoker                      int64
alcohol                     int64
FH_suicide                  int64
hi_LDL                      int64
lo_HDL                      int64
CKD3                        int64
T2DM                        int64
migraine                    int64
hypothyroid                 int64
CHD                         int64
FH_anxiety                  int64
FH_any                      int64
FH_LD                       int64
dtype: object
In [28]:
import matplotlib.pyplot as plt
%matplotlib inline
sns.set_theme(style="white", font_scale=2)  # one call: a later sns.set() would reset the style to its default
def plot_matrix_corr(df_corr, fontsize=16, xsize=9, ysize=9):
    # Annotated heatmap for a square correlation / p-value matrix
    cmap = sns.diverging_palette(230, 20, as_cmap=True)
    plt.figure(figsize=(xsize, ysize))
    sns.heatmap(df_corr, xticklabels=df_corr.columns, yticklabels=df_corr.columns,
                annot=True, linewidths=0.5, cmap=cmap, fmt='.3f',
                annot_kws={'fontsize': fontsize,
                           'fontweight': 'bold',
                           'fontfamily': 'serif'})

plot_matrix_corr(df_corr_pearson, 32, 16, 12)
plot_matrix_corr(df_corr_spearman, 32, 16, 12)
In [29]:
sns.set_theme(style="white")
plot_matrix_corr(df_corr_chisquare, 8, 22, 14)

About Conditional Entropy

Here, $X$ is the response and $Y$ is the feature being considered. Two properties:

  1. Conditional entropy equals zero:
    $H(Y|X) = 0$ if and only if the value of $Y$ is completely determined by the value of $X$.
  2. Conditional entropy of independent random variables:
    Conversely, $H(Y|X) = H(Y)$ if and only if $Y$ and $X$ are independent random variables.
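
For reference, a minimal sketch of how a conditional entropy $H(Y|X)$ can be computed for two discrete series with pandas and numpy (illustrative only; not necessarily the exact implementation used to produce the tables above):

import numpy as np
import pandas as pd

def conditional_entropy(y, x):
    # H(Y|X) = -sum over (x, y) of p(x, y) * log p(y | x), in nats
    joint = pd.crosstab(x, y, normalize=True)            # p(x, y)
    p_y_given_x = joint.div(joint.sum(axis=1), axis=0)   # p(y | x)
    vals = p_y_given_x.values
    terms = joint.values * np.log(vals, where=vals > 0, out=np.zeros_like(vals))
    return -terms.sum()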

F-Test

In [30]:
import scipy.stats as stats

def anova(X_dict, key, target):
    # One-way ANOVA of each numerical feature against the binary target
    anova_df = X_dict['num_' + key].merge(df[target], left_index=True, right_index=True)
    results_list = list()
    for feature in X_dict['num_' + key].keys():
        res = stats.f_oneway(anova_df.loc[anova_df[target]==0.0, feature].values,
                             anova_df.loc[anova_df[target]==1.0, feature].values)
        results_list.append([feature, res.statistic, res.pvalue])
    return pd.DataFrame(results_list, columns=['feature', 'F-Statistic', 'p-value'])
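
Since each comparison involves exactly two groups, the one-way ANOVA F statistic is the square of the equal-variance two-sample t statistic, so this F-test and the t-test below agree when run on the same data. A quick illustrative check on synthetic data:

import numpy as np
import scipy.stats as stats

rng = np.random.default_rng(0)
a, b = rng.normal(0.0, 1.0, 100), rng.normal(0.3, 1.0, 100)
f = stats.f_oneway(a, b)
t = stats.ttest_ind(a, b, equal_var=True)
assert np.isclose(f.statistic, t.statistic ** 2) and np.isclose(f.pvalue, t.pvalue)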
In [31]:
feature = '34'
for target in targets:
    key = feature + '_' + target
    print('key:', key)
    display(anova(X_dict, key, targets[target]))
key: 34_resp
feature F-Statistic p-value
0 age_first_exposure 1028.801036 1.693068e-221
1 age_first_diagnosis 466.040468 1.780416e-102
2 symptom_to_exposure 403.450918 4.498120e-89
key: 34_exp
feature F-Statistic p-value
0 age_first_exposure 360.151148 7.323420e-80
1 age_first_diagnosis 111.756084 4.470683e-26
2 symptom_to_exposure 178.703821 1.201673e-40
key: 34_multi
feature F-Statistic p-value
0 age_first_exposure 46.909192 7.655806e-12
1 age_first_diagnosis 27.350178 1.714934e-07
2 symptom_to_exposure 13.520850 2.365581e-04
key: 34_lithium2y
feature F-Statistic p-value
0 age_first_exposure 762.493116 7.457636e-166
1 age_first_diagnosis 347.333055 4.215269e-77
2 symptom_to_exposure 339.500883 2.051218e-75

T-Test

In [32]:
import scipy.stats as stats

for target in targets:
    key = feature + '_' + target
    print('key:', key)
    sample = 1000  # .head(sample) takes the first rows of each group, not a random sample
    ttest_df = X_dict['num_' + key].merge(df[targets[target]], left_index=True, right_index=True)
    results_list = list()
    for feat in X_dict['num_' + key].keys():
        res = stats.ttest_ind(ttest_df.loc[ttest_df[targets[target]]==0.0, feat].head(sample),
                              ttest_df.loc[ttest_df[targets[target]]==1.0, feat].head(sample),
                              equal_var=True)
        results_list.append([feat, res.statistic, res.pvalue])

    ttest_results = pd.DataFrame(results_list, columns=['feature', 'T-Statistic', 'p-value'])
    display(ttest_results)
key: 34_resp
feature T-Statistic p-value
0 age_first_exposure -8.560086 2.212296e-17
1 age_first_diagnosis -4.615607 4.168509e-06
2 symptom_to_exposure -6.759800 1.807514e-11
key: 34_exp
feature T-Statistic p-value
0 age_first_exposure 5.359331 9.317705e-08
1 age_first_diagnosis 3.378131 7.438096e-04
2 symptom_to_exposure 3.309470 9.513413e-04
key: 34_multi
feature T-Statistic p-value
0 age_first_exposure -3.410012 0.000663
1 age_first_diagnosis -2.348080 0.018967
2 symptom_to_exposure -2.479474 0.013240
key: 34_lithium2y
feature T-Statistic p-value
0 age_first_exposure -7.933297 3.521460e-15
1 age_first_diagnosis -3.810229 1.430358e-04
2 symptom_to_exposure -6.630426 4.294092e-11

Dataset Splitting

We split by time (exposure start date): the first 80% of patients form the training set and the last 20% the test set. Splitting chronologically keeps the test patients strictly later than the training patients, so no information about future prescribing leaks into training.

In [33]:
x_train = dict()
y_train = dict()
x_test = dict()
y_test = dict()

date_format = '%d%b%Y'  # renamed so we do not shadow the built-in format()
df['exposure_start'] = pd.to_datetime(df['exposure_start'], format=date_format)
for feature_set in X_dict:
    print(feature_set)
    # Sort patients chronologically by exposure start before taking the split
    X_sorted = X_dict[feature_set].merge(df['exposure_start'], left_index=True, right_index=True).sort_values(by='exposure_start')
    limit = int(len(X_sorted)*.8) # We keep 80% for the training set
    train_index = X_sorted.index[:limit]
    test_index = X_sorted.index[limit:]
    x_train[feature_set] = X_dict[feature_set][X_dict[feature_set].index.isin(train_index)]
    y_train[feature_set] = y_dict[feature_set][y_dict[feature_set].index.isin(train_index)]
    x_test[feature_set] = X_dict[feature_set][X_dict[feature_set].index.isin(test_index)]
    y_test[feature_set] = y_dict[feature_set][y_dict[feature_set].index.isin(test_index)]
34_resp
num_34_resp
cat_34_resp
bin_34_resp
lit_34_resp
ola_34_resp
34_exp
num_34_exp
cat_34_exp
bin_34_exp
34_multi
num_34_multi
cat_34_multi
bin_34_multi
lit_34_multi
ola_34_multi
34_lithium2y
num_34_lithium2y
cat_34_lithium2y
bin_34_lithium2y
37_resp
num_37_resp
cat_37_resp
bin_37_resp
lit_37_resp
ola_37_resp
37_exp
num_37_exp
cat_37_exp
bin_37_exp
37_multi
num_37_multi
cat_37_multi
bin_37_multi
lit_37_multi
ola_37_multi
37_lithium2y
num_37_lithium2y
cat_37_lithium2y
bin_37_lithium2y
13_resp
num_13_resp
cat_13_resp
bin_13_resp
lit_13_resp
ola_13_resp
13_exp
num_13_exp
cat_13_exp
bin_13_exp
13_multi
num_13_multi
cat_13_multi
bin_13_multi
lit_13_multi
ola_13_multi
13_lithium2y
num_13_lithium2y
cat_13_lithium2y
bin_13_lithium2y

Evaluation function

Balanced accuracy is the same as macro-averaged recall; see e.g. https://github.com/EpistasisLab/tpot/issues/108, and it is also confirmed here: https://amueller.github.io/aml/04-model-evaluation/10-evaluation-metrics.html
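
A quick illustrative check of this equivalence with scikit-learn, on toy labels:

import numpy as np
from sklearn.metrics import balanced_accuracy_score, recall_score

y_true = [0, 0, 1, 1, 1]
y_pred = [0, 1, 1, 1, 0]
assert np.isclose(balanced_accuracy_score(y_true, y_pred),
                  recall_score(y_true, y_pred, average='macro'))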
In [92]:
from sklearn.metrics import roc_auc_score

def results(clf, X1, y, v):
    X = X1
    if isinstance(X1, pd.Series):
        X = X1.to_frame()

    y_pred = clf.predict(X)
    y_score = clf.predict_proba(X)[:, 1]
    target_names = ['No Response', 'Response']
    multiclass = 'raise'

    if '_multi' in v:
        y_score = clf.predict_proba(X)
        multiclass = 'ovr'
        target_names = ['No Response', 'Equivocal', 'Response']
    if '_exp' in v:
        target_names = ['Lithium', 'Olanzapine']
    if '_lithium2y' in v:
        target_names = ['other', 'lithium2y']

    #print(classification_report(y, y_pred, target_names=target_names))
    result = classification_report(y, y_pred, target_names=target_names, output_dict=True)
    # Compute confusion matrix
    cm = confusion_matrix(y, y_pred)
    #print(cm)

    print('Balanced Accuracy:', result['macro avg']['recall'])

    # For average and multi_class parameters, check doc:
    # https://scikit-learn.org/stable/modules/generated/sklearn.metrics.roc_auc_score.html#sklearn.metrics.roc_auc_score
    try:
        roc_auc_score_val = roc_auc_score(y, y_score, average='weighted', multi_class=multiclass)
    except Exception:
        # 34_multi fails with Naive Bayes: MixedNB predict_proba() seems to
        # return probabilities that do not sum to 1
        roc_auc_score_val = None
    result['roc_auc'] = roc_auc_score_val
    print('ROC_AUC score:', roc_auc_score_val)

    # Show confusion matrix in a separate window
    color_map = plt.cm.get_cmap('Blues')
    plt.matshow(cm, cmap=color_map)
    plt.title('Confusion matrix\n')
    plt.colorbar()
    plt.ylabel('True label')
    plt.xlabel('Predicted label')
    plt.show()

    return(result)
In [76]:
import warnings

def evaluate(clf, x_test, y_test):
    all_results = pd.DataFrame(columns=['features',
                                        'balanced accuracy',
                                        'accuracy',
                                        'roc_auc',
                                        'f1 (response)',
                                        'f1 (equivocal)',
                                        'f1 (no response)',
                                        'f1 (lithium)',
                                        'f1 (olanzapine)',
                                        'f1 (lithium > 2y)',
                                        'f1 (other)',
                                        'f1 (weighted avg)'])
    for v in clf:
        print(40*'_' + v + 40*'_')
        v2 = v.replace('_balanced_accuracy', '').replace('_accuracy', '').replace('_f1_weighted', '')
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")
            result = results(clf[v], x_test[v2], y_test[v2], v)
        result_dict = {'features': v,
                        'accuracy': result['accuracy'],
                        'roc_auc': result['roc_auc'],
                       'balanced accuracy': result['macro avg']['recall'],
                        'f1 (weighted avg)': result['weighted avg']['f1-score']
                       }
        if '_exp' in v:
            result_dict.update({
                'f1 (lithium)': result['Lithium']['f1-score'],
                'f1 (olanzapine)': result['Olanzapine']['f1-score'],

            })
        elif '_multi' in v:
            result_dict.update({
                'f1 (no response)': result['No Response']['f1-score'],
                'f1 (equivocal)': result['Equivocal']['f1-score'],
                'f1 (response)': result['Response']['f1-score']
            })
        elif '_lithium2y' in v:
            result_dict.update({
                'f1 (other)': result['other']['f1-score'],
                'f1 (lithium > 2y)': result['lithium2y']['f1-score']
            })
        else:
            result_dict.update({
                'f1 (no response)': result['No Response']['f1-score'],
                'f1 (response)': result['Response']['f1-score']
            })

        # DataFrame.append was removed in pandas 2.0; pd.concat is equivalent
        all_results = pd.concat([all_results, pd.DataFrame([result_dict])],
                                ignore_index=True)
        print(83*'_')
        print('\n\n\n')
    return all_results

Logistic Regression with Elastic Net regularisation
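
With the elastic-net penalty, scikit-learn minimises the log-loss plus a mixed regulariser: for mixing ratio $\rho$ (l1_ratio) the objective is $\min_w\; \rho\,\lVert w\rVert_1 + \tfrac{1-\rho}{2}\,\lVert w\rVert_2^2 + C\sum_i \mathrm{logloss}_i(w)$. LogisticRegressionCV cross-validates jointly over its Cs grid and every value passed in l1_ratios.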

In [68]:
from sklearn.utils._testing import ignore_warnings
from sklearn.exceptions import ConvergenceWarning
from sklearn.linear_model import LogisticRegressionCV

@ignore_warnings(category=ConvergenceWarning) # saga often needs more than the default max_iter (100), which triggers this warning
def run(X1, y, l1, l2):
    X = X1
    if isinstance(X1, pd.Series):
        X = X1.to_frame()
    return LogisticRegressionCV(penalty='elasticnet',
                                l1_ratios=[l1, l2],
                                cv=5,
                                solver='saga',
                                scoring='balanced_accuracy',
                                n_jobs=4).fit(X, y)
In [69]:
clf_multi_dict = dict()

# Elastic net regularisation for larger feature sets
l1 = .8
l2 = .2

# Note: LogisticRegressionCV cross-validates over every value in l1_ratios,
# so passing (l1, l2) and (l2, l1) search the same grid either way.
for v in [f for f in X_dict if len(X_dict[f].columns) > 5]:
    print(v)
    clf_multi_dict[v] = run(x_train[v], y_train[v], l1, l2) # a lot of features (more than 5)

for v in [f for f in X_dict if len(X_dict[f].columns) <= 5]:
    print(v)
    clf_multi_dict[v] = run(x_train[v], y_train[v], l2, l1) # NOT a lot of features (5 or less)
34_resp
cat_34_resp
bin_34_resp
lit_34_resp
ola_34_resp
34_exp
cat_34_exp
bin_34_exp
34_multi
cat_34_multi
bin_34_multi
lit_34_multi
ola_34_multi
34_lithium2y
cat_34_lithium2y
bin_34_lithium2y
37_resp
cat_37_resp
bin_37_resp
lit_37_resp
ola_37_resp
37_exp
cat_37_exp
bin_37_exp
37_multi
cat_37_multi
bin_37_multi
lit_37_multi
ola_37_multi
37_lithium2y
cat_37_lithium2y
bin_37_lithium2y
13_resp
cat_13_resp
bin_13_resp
lit_13_resp
ola_13_resp
13_exp
cat_13_exp
bin_13_exp
13_multi
cat_13_multi
bin_13_multi
lit_13_multi
ola_13_multi
13_lithium2y
cat_13_lithium2y
bin_13_lithium2y
num_34_resp
num_34_exp
num_34_multi
num_34_lithium2y
num_37_resp
num_37_exp
num_37_multi
num_37_lithium2y
num_13_resp
num_13_exp
num_13_multi
num_13_lithium2y

Results (interpret ROC_AUC with care because of class imbalance)

In [77]:
all_results = evaluate(clf_multi_dict, x_test, y_test)
________________________________________34_resp________________________________________
Balanced Accuracy: 0.5807004496516058
ROC_AUC score: 0.6445436122012334
___________________________________________________________________________________




________________________________________cat_34_resp________________________________________
Balanced Accuracy: 0.5206806844315283
ROC_AUC score: 0.5627429177011705
___________________________________________________________________________________




________________________________________bin_34_resp________________________________________
Balanced Accuracy: 0.517216479216485
ROC_AUC score: 0.5665701765426024
___________________________________________________________________________________




________________________________________lit_34_resp________________________________________
Balanced Accuracy: 0.576376784572234
ROC_AUC score: 0.6225357816631882
___________________________________________________________________________________




________________________________________ola_34_resp________________________________________
Balanced Accuracy: 0.5623262698430254
ROC_AUC score: 0.6852710006119724
___________________________________________________________________________________




________________________________________34_exp________________________________________
Balanced Accuracy: 0.5567292093418088
ROC_AUC score: 0.597591654042374
___________________________________________________________________________________




________________________________________cat_34_exp________________________________________
Balanced Accuracy: 0.5528329793498448
ROC_AUC score: 0.585100734306033
___________________________________________________________________________________




________________________________________bin_34_exp________________________________________
Balanced Accuracy: 0.5471165025089209
ROC_AUC score: 0.5855459922694032
___________________________________________________________________________________




________________________________________34_multi________________________________________
Balanced Accuracy: 0.3857732604055686
ROC_AUC score: 0.6075739122604424
___________________________________________________________________________________




________________________________________cat_34_multi________________________________________
Balanced Accuracy: 0.34700607363579766
ROC_AUC score: 0.5446438972107863
___________________________________________________________________________________




________________________________________bin_34_multi________________________________________
Balanced Accuracy: 0.3461819988708128
ROC_AUC score: 0.5478197984652663
___________________________________________________________________________________




________________________________________lit_34_multi________________________________________
Balanced Accuracy: 0.38412353361957013
ROC_AUC score: 0.5939621512078251
___________________________________________________________________________________




________________________________________ola_34_multi________________________________________
Balanced Accuracy: 0.37496773392901467
ROC_AUC score: 0.6348224798138037
___________________________________________________________________________________




________________________________________34_lithium2y________________________________________
Balanced Accuracy: 0.5138532834389381
ROC_AUC score: 0.6187926946544635
___________________________________________________________________________________




________________________________________cat_34_lithium2y________________________________________
Balanced Accuracy: 0.5
ROC_AUC score: 0.5
___________________________________________________________________________________




________________________________________bin_34_lithium2y________________________________________
Balanced Accuracy: 0.5
ROC_AUC score: 0.5
___________________________________________________________________________________




________________________________________37_resp________________________________________
Balanced Accuracy: 0.5830973043786685
ROC_AUC score: 0.6499600977105525
___________________________________________________________________________________




________________________________________cat_37_resp________________________________________
Balanced Accuracy: 0.5210064816192035
ROC_AUC score: 0.5626445206576584
___________________________________________________________________________________




________________________________________bin_37_resp________________________________________
Balanced Accuracy: 0.517216479216485
ROC_AUC score: 0.5665701765426026
___________________________________________________________________________________




________________________________________lit_37_resp________________________________________
Balanced Accuracy: 0.5786410686683147
ROC_AUC score: 0.6238457544171984
___________________________________________________________________________________




________________________________________ola_37_resp________________________________________
Balanced Accuracy: 0.5715256733541256
ROC_AUC score: 0.6935510628414695
___________________________________________________________________________________




________________________________________37_exp________________________________________
Balanced Accuracy: 0.5539152712838864
ROC_AUC score: 0.5994841445594751
___________________________________________________________________________________




________________________________________cat_37_exp________________________________________
Balanced Accuracy: 0.5491100712496624
ROC_AUC score: 0.5844461374697395
___________________________________________________________________________________




________________________________________bin_37_exp________________________________________
Balanced Accuracy: 0.5471165025089209
ROC_AUC score: 0.5855459398429387
___________________________________________________________________________________




________________________________________37_multi________________________________________
Balanced Accuracy: 0.3875424076373301
ROC_AUC score: 0.6095569396838123
___________________________________________________________________________________




________________________________________cat_37_multi________________________________________
Balanced Accuracy: 0.34708807903337985
ROC_AUC score: 0.5440369409509134
___________________________________________________________________________________




________________________________________bin_37_multi________________________________________
Balanced Accuracy: 0.3461819988708128
ROC_AUC score: 0.5478197389143789
___________________________________________________________________________________




________________________________________lit_37_multi________________________________________
Balanced Accuracy: 0.3850145106903992
ROC_AUC score: 0.5951377257494506
___________________________________________________________________________________




________________________________________ola_37_multi________________________________________
Balanced Accuracy: 0.38150595681310495
ROC_AUC score: 0.6409861821257834
___________________________________________________________________________________




________________________________________37_lithium2y________________________________________
Balanced Accuracy: 0.512551877672003
ROC_AUC score: 0.619842045721224
___________________________________________________________________________________




________________________________________cat_37_lithium2y________________________________________
Balanced Accuracy: 0.5
ROC_AUC score: 0.5
___________________________________________________________________________________




________________________________________bin_37_lithium2y________________________________________
Balanced Accuracy: 0.5
ROC_AUC score: 0.5
___________________________________________________________________________________




________________________________________13_resp________________________________________
Balanced Accuracy: 0.5822266049587533
ROC_AUC score: 0.6433300963375705
___________________________________________________________________________________




________________________________________cat_13_resp________________________________________
Balanced Accuracy: 0.5
ROC_AUC score: 0.5031357621766341
___________________________________________________________________________________




________________________________________bin_13_resp________________________________________
Balanced Accuracy: 0.5
ROC_AUC score: 0.5
___________________________________________________________________________________




________________________________________lit_13_resp________________________________________
Balanced Accuracy: 0.576376784572234
ROC_AUC score: 0.6225498156535128
___________________________________________________________________________________




________________________________________ola_13_resp________________________________________
Balanced Accuracy: 0.5519614457298325
ROC_AUC score: 0.6735587309311567
___________________________________________________________________________________




________________________________________13_exp________________________________________
Balanced Accuracy: 0.519108292943985
ROC_AUC score: 0.5672112518097616
___________________________________________________________________________________




________________________________________cat_13_exp________________________________________
Balanced Accuracy: 0.5054545542240632
ROC_AUC score: 0.5338737348445577
___________________________________________________________________________________




________________________________________bin_13_exp________________________________________
Balanced Accuracy: 0.5003936178957946
ROC_AUC score: 0.5499176485094946
___________________________________________________________________________________




________________________________________13_multi________________________________________
Balanced Accuracy: 0.3878394785017754
ROC_AUC score: 0.6070247192189846
___________________________________________________________________________________




________________________________________cat_13_multi________________________________________
Balanced Accuracy: 0.3333333333333333
ROC_AUC score: 0.5003737582196537
___________________________________________________________________________________




________________________________________bin_13_multi________________________________________
Balanced Accuracy: 0.3333333333333333
ROC_AUC score: 0.5
___________________________________________________________________________________




________________________________________lit_13_multi________________________________________
Balanced Accuracy: 0.38402994192011985
ROC_AUC score: 0.5939708088740229
___________________________________________________________________________________




________________________________________ola_13_multi________________________________________
Balanced Accuracy: 0.36855795482750064
ROC_AUC score: 0.6250043455668536
___________________________________________________________________________________




________________________________________13_lithium2y________________________________________
Balanced Accuracy: 0.5084324333976145
ROC_AUC score: 0.6024577644964135
___________________________________________________________________________________




________________________________________cat_13_lithium2y________________________________________
Balanced Accuracy: 0.5
ROC_AUC score: 0.5
___________________________________________________________________________________




________________________________________bin_13_lithium2y________________________________________
Balanced Accuracy: 0.5
ROC_AUC score: 0.5
___________________________________________________________________________________




________________________________________num_34_resp________________________________________
Balanced Accuracy: 0.5824969107904944
ROC_AUC score: 0.6438724242285558
___________________________________________________________________________________




________________________________________num_34_exp________________________________________
Balanced Accuracy: 0.5
ROC_AUC score: 0.5455372076910043
___________________________________________________________________________________




________________________________________num_34_multi________________________________________
Balanced Accuracy: 0.38732593956959604
ROC_AUC score: 0.6073454289552751
___________________________________________________________________________________




________________________________________num_34_lithium2y________________________________________
Balanced Accuracy: 0.5045200853000296
ROC_AUC score: 0.6013688754374688
___________________________________________________________________________________




________________________________________num_37_resp________________________________________
Balanced Accuracy: 0.5860582258784225
ROC_AUC score: 0.64826060342559
___________________________________________________________________________________




________________________________________num_37_exp________________________________________
Balanced Accuracy: 0.5007556750599339
ROC_AUC score: 0.5560399061859873
___________________________________________________________________________________




________________________________________num_37_multi________________________________________
Balanced Accuracy: 0.3902299311659596
ROC_AUC score: 0.6105212605735918
___________________________________________________________________________________




________________________________________num_37_lithium2y________________________________________
Balanced Accuracy: 0.5076257308151458
ROC_AUC score: 0.6052576282816533
___________________________________________________________________________________




________________________________________num_13_resp________________________________________
Balanced Accuracy: 0.5824969107904944
ROC_AUC score: 0.6438352392993214
___________________________________________________________________________________




________________________________________num_13_exp________________________________________
Balanced Accuracy: 0.5
ROC_AUC score: 0.5454701018163882
___________________________________________________________________________________




________________________________________num_13_multi________________________________________
Balanced Accuracy: 0.38732593956959604
ROC_AUC score: 0.6073560975471447
___________________________________________________________________________________




________________________________________num_13_lithium2y________________________________________
Balanced Accuracy: 0.5044206026100176
ROC_AUC score: 0.6013454396114563
___________________________________________________________________________________




In [78]:
display(all_results)
features balanced accuracy accuracy roc_auc f1 (response) f1 (equivocal) f1 (no response) f1 (lithium) f1 (olanzapine) f1 (lithium > 2y) f1 (other) f1 (weighted avg)
0 34_resp 0.580700 0.600151 0.644544 0.416942 NaN 0.695752 NaN NaN NaN NaN 0.566286
1 cat_34_resp 0.520681 0.551113 0.562743 0.162562 NaN 0.693378 NaN NaN NaN NaN 0.446892
2 bin_34_resp 0.517216 0.548849 0.566570 0.131445 NaN 0.695287 NaN NaN NaN NaN 0.433465
3 lit_34_resp 0.576377 0.561359 0.622536 0.527458 NaN 0.590722 NaN NaN NaN NaN 0.555586
4 ola_34_resp 0.562326 0.635294 0.685271 0.297371 NaN 0.753734 NaN NaN NaN NaN 0.570544
5 34_exp 0.556729 0.522952 0.597592 NaN NaN NaN 0.564147 0.473156 NaN NaN 0.510661
6 cat_34_exp 0.552833 0.517692 0.585101 NaN NaN NaN 0.562717 0.462331 NaN NaN 0.503708
7 bin_34_exp 0.547117 0.510360 0.585546 NaN NaN NaN 0.560137 0.447879 NaN NaN 0.494150
8 34_multi 0.385773 0.509882 0.607574 0.394888 0.0 0.626716 NaN NaN NaN NaN 0.442636
9 cat_34_multi 0.347006 0.469238 0.544644 0.158200 0.0 0.623879 NaN NaN NaN NaN 0.347330
10 bin_34_multi 0.346182 0.468441 0.547820 0.150086 0.0 0.623939 NaN NaN NaN NaN 0.344135
11 lit_34_multi 0.384124 0.471301 0.593962 0.491928 0.0 0.528982 NaN NaN NaN NaN 0.427151
12 ola_34_multi 0.374968 0.542003 0.634822 0.290523 0.0 0.680728 NaN NaN NaN NaN 0.447247
13 34_lithium2y 0.513853 0.801084 0.618793 NaN NaN NaN NaN NaN 0.068657 0.888651 0.725541
14 cat_34_lithium2y 0.500000 0.801084 0.500000 NaN NaN NaN NaN NaN 0.000000 0.889558 0.712610
15 bin_34_lithium2y 0.500000 0.801084 0.500000 NaN NaN NaN NaN NaN 0.000000 0.889558 0.712610
16 37_resp 0.583097 0.602603 0.649960 0.419719 NaN 0.697835 NaN NaN NaN NaN 0.568691
17 cat_37_resp 0.521006 0.551867 0.562645 0.154448 NaN 0.695150 NaN NaN NaN NaN 0.444074
18 bin_37_resp 0.517216 0.548849 0.566570 0.131445 NaN 0.695287 NaN NaN NaN NaN 0.433465
19 lit_37_resp 0.578641 0.563247 0.623846 0.527891 NaN 0.593677 NaN NaN NaN NaN 0.557141
20 ola_37_resp 0.571526 0.643765 0.693551 0.316170 NaN 0.759147 NaN NaN NaN NaN 0.581331
21 37_exp 0.553915 0.518489 0.599484 NaN NaN NaN 0.564006 0.462360 NaN NaN 0.504256
22 cat_37_exp 0.549110 0.512432 0.584446 NaN NaN NaN 0.561685 0.450709 NaN NaN 0.496451
23 bin_37_exp 0.547117 0.510360 0.585546 NaN NaN NaN 0.560137 0.447879 NaN NaN 0.494150
24 37_multi 0.387542 0.512113 0.609557 0.398541 0.0 0.628659 NaN NaN NaN NaN 0.444973
25 cat_37_multi 0.347088 0.469398 0.544037 0.157211 0.0 0.624066 NaN NaN NaN NaN 0.347023
26 bin_37_multi 0.346182 0.468441 0.547820 0.150086 0.0 0.623939 NaN NaN NaN NaN 0.344135
27 lit_37_multi 0.385015 0.471827 0.595138 0.490578 0.0 0.530845 NaN NaN NaN NaN 0.427218
28 ola_37_multi 0.381506 0.548869 0.640986 0.313316 0.0 0.685266 NaN NaN NaN NaN 0.457389
29 37_lithium2y 0.512552 0.800446 0.619842 NaN NaN NaN NaN NaN 0.064275 0.888314 0.724399
30 cat_37_lithium2y 0.500000 0.801084 0.500000 NaN NaN NaN NaN NaN 0.000000 0.889558 0.712610
31 bin_37_lithium2y 0.500000 0.801084 0.500000 NaN NaN NaN NaN NaN 0.000000 0.889558 0.712610
32 13_resp 0.582227 0.599585 0.643330 0.439989 NaN 0.688390 NaN NaN NaN NaN 0.573044
33 cat_13_resp 0.500000 0.535647 0.503136 0.000000 NaN 0.697617 NaN NaN NaN NaN 0.373677
34 bin_13_resp 0.500000 0.535647 0.500000 0.000000 NaN 0.697617 NaN NaN NaN NaN 0.373677
35 lit_13_resp 0.576377 0.561359 0.622550 0.527458 NaN 0.590722 NaN NaN NaN NaN 0.555586
36 ola_13_resp 0.551961 0.626353 0.673559 0.272894 NaN 0.748575 NaN NaN NaN NaN 0.557631
37 13_exp 0.519108 0.441026 0.567211 NaN NaN NaN 0.586974 0.135568 NaN NaN 0.321628
38 cat_13_exp 0.505455 0.421103 0.533874 NaN NaN NaN 0.583963 0.048717 NaN NaN 0.269333
39 bin_13_exp 0.500394 0.413931 0.549918 NaN NaN NaN 0.582681 0.016056 NaN NaN 0.249606
40 13_multi 0.387839 0.510679 0.607025 0.417533 0.0 0.622216 NaN NaN NaN NaN 0.449578
41 cat_13_multi 0.333333 0.456009 0.500374 0.000000 0.0 0.626382 NaN NaN NaN NaN 0.285636
42 bin_13_multi 0.333333 0.456009 0.500000 0.000000 0.0 0.626382 NaN NaN NaN NaN 0.285636
43 lit_13_multi 0.384030 0.471301 0.593971 0.492570 0.0 0.528489 NaN NaN NaN NaN 0.427267
44 ola_13_multi 0.368558 0.535137 0.625004 0.265834 0.0 0.678005 NaN NaN NaN NaN 0.437382
45 13_lithium2y 0.508432 0.801084 0.602458 NaN NaN NaN NaN NaN 0.042945 0.889007 0.720712
46 cat_13_lithium2y 0.500000 0.801084 0.500000 NaN NaN NaN NaN NaN 0.000000 0.889558 0.712610
47 bin_13_lithium2y 0.500000 0.801084 0.500000 NaN NaN NaN NaN NaN 0.000000 0.889558 0.712610
48 num_34_resp 0.582497 0.599585 0.643872 0.442928 NaN 0.687472 NaN NaN NaN NaN 0.573918
49 num_34_exp 0.500000 0.412177 0.545537 NaN NaN NaN 0.583747 0.000000 NaN NaN 0.240607
50 num_34_multi 0.387326 0.509882 0.607345 0.418582 0.0 0.620580 NaN NaN NaN NaN 0.449249
51 num_34_lithium2y 0.504520 0.800606 0.601369 NaN NaN NaN NaN NaN 0.024942 0.888948 0.717083
52 num_37_resp 0.586058 0.603169 0.648261 0.447479 NaN 0.690406 NaN NaN NaN NaN 0.577602
53 num_37_exp 0.500756 0.413134 0.556040 NaN NaN NaN 0.584049 0.003788 NaN NaN 0.242958
54 num_37_multi 0.390230 0.513548 0.610521 0.425006 0.0 0.623219 NaN NaN NaN NaN 0.453004
55 num_37_lithium2y 0.507626 0.801721 0.605258 NaN NaN NaN NaN NaN 0.037152 0.889481 0.719939
56 num_13_resp 0.582497 0.599585 0.643835 0.442928 NaN 0.687472 NaN NaN NaN NaN 0.573918
57 num_13_exp 0.500000 0.412177 0.545470 NaN NaN NaN 0.583747 0.000000 NaN NaN 0.240607
58 num_13_multi 0.387326 0.509882 0.607356 0.418582 0.0 0.620580 NaN NaN NaN NaN 0.449249
59 num_13_lithium2y 0.504421 0.800446 0.601345 NaN NaN NaN NaN NaN 0.024922 0.888849 0.717000

SHAP

In [40]:
import shap

def shap_plot(clf, X, sample=50):
    # Draw the subsample once, so the background set, the explained rows and
    # the plotted rows are the same observations
    X_sample = shap.sample(X, sample)
    explainer = shap.KernelExplainer(clf.predict, X_sample)
    shap_values = explainer.shap_values(X_sample, l1_reg="num_features(" + str(X.shape[1]) + ")")
    shap.summary_plot(shap_values, X_sample, feature_names=X.columns)
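
KernelExplainer is model-agnostic but costly: runtime grows with both the size of the background set and the number of rows being explained, which is why both are subsampled here.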
In [41]:
shap_plot(clf_multi_dict['13_lithium2y'], X_dict['13_lithium2y'], 100)
X does not have valid feature names, but LogisticRegressionCV was fitted with feature names
(the same warning is emitted for every KernelExplainer model call; repeats omitted)

UMAP

In [42]:
def umap_embedding(X_dict, n_neighbors=15, weight=0.5):
    # One UMAP graph per data type: Bray-Curtis on the numeric frame,
    # Jaccard on the binary frame. The two fuzzy simplicial sets are then
    # intersected (`weight` balances them) and the combined graph embedded.
    import umap.umap_ as umap
    fit1 = umap.UMAP(n_neighbors=n_neighbors, metric='braycurtis', random_state=42).fit(X_dict['num'].values)
    fit2 = umap.UMAP(n_neighbors=n_neighbors, metric='jaccard', random_state=42).fit(X_dict['bin'].values)
    intersection = umap.general_simplicial_set_intersection(fit1.graph_, fit2.graph_, weight=weight)
    intersection = umap.reset_local_connectivity(intersection)
    # simplicial_set_embedding returns (coordinates, aux_data) in recent umap versions.
    embedding = umap.simplicial_set_embedding(fit1._raw_data, intersection, fit1.n_components,
                                              fit1._initial_alpha, fit1._a, fit1._b,
                                              fit1.repulsion_strength, fit1.negative_sample_rate,
                                              200, 'random', np.random, fit1.metric,
                                              fit1._metric_kwds, False,
                                              densmap_kwds={}, output_dens=False)
    return embedding
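A minimal usage sketch (assuming the 'num' and 'bin' frames in X_dict, as above; in recent umap versions simplicial_set_embedding returns a (coordinates, aux_data) tuple):

In [ ]:
# Hypothetical call: blend the numeric and binary graphs equally.
emb = umap_embedding(X_dict, n_neighbors=30, weight=0.5)
coords = emb[0]  # (n_samples, 2) embedding coordinates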
In [43]:
from mpl_toolkits.mplot3d import Axes3D
import seaborn as sns
sns.set(style='white', context='poster', rc={'figure.figsize':(14,10)})

def plot_umap(X_dict, n_neighbors, weight, min_dist=0.1, n_components=2):
    # Same mixed-metric embedding as umap_embedding(), then a 2-D or 3-D
    # scatter coloured by exposure. NB: relies on the module-level X
    # (for X.exposure), not on X_dict.
    import umap.umap_ as umap
    import matplotlib.pyplot as plt
    print('n_neighbors={:.0f} min_dist={:.1f} weight={:.1f} {:.0f}D'.format(n_neighbors, min_dist, weight, n_components))
    fit1 = umap.UMAP(n_neighbors=n_neighbors, metric='braycurtis', random_state=42, min_dist=min_dist, n_components=n_components).fit(X_dict['num'].values)
    fit2 = umap.UMAP(n_neighbors=n_neighbors, metric='jaccard', random_state=42, min_dist=min_dist, n_components=n_components).fit(X_dict['bin'].values)
    intersection = umap.general_simplicial_set_intersection(fit1.graph_, fit2.graph_, weight=weight)
    intersection = umap.reset_local_connectivity(intersection)
    embedding = umap.simplicial_set_embedding(fit1._raw_data, intersection, fit1.n_components,
                                              fit1._initial_alpha, fit1._a, fit1._b,
                                              fit1.repulsion_strength, fit1.negative_sample_rate,
                                              200, 'random', np.random, fit1.metric,
                                              fit1._metric_kwds, False,
                                              densmap_kwds={}, output_dens=False)
    plt.clf()
    fig = plt.figure()
    if n_components == 3:
        ax = fig.add_subplot(111, projection='3d')
        ax.scatter(embedding[0][:, 0],
                   embedding[0][:, 1],
                   embedding[0][:, 2],
                   c=[sns.color_palette()[x] for x in X.exposure])
    else:
        ax = fig.add_subplot(111)
        ax.scatter(embedding[0][:, 0],
                   embedding[0][:, 1],
                   c=[sns.color_palette()[x] for x in X.exposure])
    #plt.gca().set_aspect('equal', 'datalim')
    plt.title('n_neighbors={:.0f} min_dist={:.1f} weight={:.1f} {:.0f}D'.format(n_neighbors, min_dist, weight, n_components), fontsize=16)
    plt.savefig('umap_images/n' + str(n_neighbors) + '_md{:.1f}_w{:.1f}_'.format(min_dist, weight) + str(n_components) + 'd.png')
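plot_umap saves a PNG rather than returning the embedding; a hypothetical call (note the colouring relies on the global X.exposure):

In [ ]:
plot_umap(X_dict, n_neighbors=100, weight=0.5, min_dist=0.3, n_components=2)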
In [105]:
def plot_umap_2y(X_dict, n_neighbors, weight, min_dist=0.1, n_components=2):
    # As plot_umap(), but on the 13-feature lithium2y frames, coloured by
    # the global y_dict['13_lithium2y'] labels; returns the embedding.
    import umap.umap_ as umap
    import matplotlib.pyplot as plt
    import matplotlib.patches as mpatches
    print('n_neighbors={:.0f} min_dist={:.1f} weight={:.1f} {:.0f}D'.format(n_neighbors, min_dist, weight, n_components))
    fit1 = umap.UMAP(n_neighbors=n_neighbors, metric='braycurtis', random_state=42, min_dist=min_dist, n_components=n_components).fit(X_dict['num_13_lithium2y'].values)
    fit2 = umap.UMAP(n_neighbors=n_neighbors, metric='jaccard', random_state=42, min_dist=min_dist, n_components=n_components).fit(X_dict['cat_13_lithium2y'].values)
    intersection = umap.general_simplicial_set_intersection(fit1.graph_, fit2.graph_, weight=weight)
    intersection = umap.reset_local_connectivity(intersection)
    embedding = umap.simplicial_set_embedding(fit1._raw_data, intersection, fit1.n_components,
                                              fit1._initial_alpha, fit1._a, fit1._b,
                                              fit1.repulsion_strength, fit1.negative_sample_rate,
                                              200, 'random', np.random, fit1.metric,
                                              fit1._metric_kwds, False,
                                              densmap_kwds={}, output_dens=False)
    plt.clf()
    fig = plt.figure(figsize=(36, 18))

    if n_components == 3:
        ax = fig.add_subplot(111, projection='3d')
        ax.scatter(embedding[0][:, 0],
                   embedding[0][:, 1],
                   embedding[0][:, 2],
                   c=[sns.color_palette()[col] for col in y_dict['13_lithium2y'].astype(int)])
        classes = ['Lithium for more than 2 years', 'Other']
        class_colours = sns.color_palette()
        # One legend patch per class; the original looped over the whole
        # palette and re-drew the legend on every iteration.
        recs = [mpatches.Rectangle((0, 0), 1, 1, fc=class_colours[i])
                for i in range(len(classes))]
        ax.legend(recs, classes, loc=1)
    else:
        ax = fig.add_subplot(111)
        ax.scatter(embedding[0][:, 0],
                   embedding[0][:, 1],
                   c=[sns.color_palette()[x] for x in y_dict['13_lithium2y'].astype(int)])
    #plt.gca().set_aspect('equal', 'datalim')
    plt.title('n_neighbors={:.0f} min_dist={:.1f} weight={:.1f} {:.0f}D'.format(n_neighbors, min_dist, weight, n_components), fontsize=16)

    plt.savefig('umap_images/n' + str(n_neighbors) + '_md{:.1f}_w{:.1f}_'.format(min_dist, weight) + str(n_components) + 'd_2y.png')
    return embedding
In [106]:
embedding = plot_umap_2y(X_dict, 50, 0, 0.1, 3)
n_neighbors=50 min_dist=0.1 weight=0.0 3D
gradient function is not yet implemented for jaccard distance metric; inverse_transform will be unavailable
Failed to correctly find n_neighbors for some samples. Results may be less than ideal. Try re-running with different parameters.
A few of your vertices were disconnected from the manifold.  This shouldn't cause problems.
Disconnection_distance = 1 has removed 250 edges.
It has only fully disconnected 5 vertices.
Use umap.utils.disconnected_vertices() to identify them.
<Figure size 1008x720 with 0 Axes>
In [ ]:
# Grid search over UMAP hyperparameters; saves one PNG per combination.
for n_neighbors in range(50, 250, 20):
    for weight in np.linspace(0, 1, 5):
        for n_components in range(2, 4):
            for min_dist in np.linspace(0.1, 0.9, 4):
                plot_umap_2y(X_dict, n_neighbors, weight, min_dist=min_dist, n_components=n_components)
print('done')

Silhouette Score

In [ ]:
from sklearn.metrics import silhouette_samples, silhouette_score
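The cell above only imports the metric; a sketch of how it could score the embedding from In [106] (assumes `embedding` and `y_dict['13_lithium2y']` are still in scope):

In [ ]:
# How cleanly do the lithium2y labels separate in the UMAP space?
coords = embedding[0]
labels = y_dict['13_lithium2y'].astype(int)
print('silhouette score:', silhouette_score(coords, labels))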

PCA

In [ ]:
from sklearn.decomposition import PCA
from sklearn.pipeline import Pipeline
pca = PCA(n_components=9)

# Pipeline.fit() refits both steps in place (it does not clone), so `pca` is
# fitted and inspectable below; note it also refits clf_multi_dict['all34'].
pipe = Pipeline([('pca', pca), ('logistic', clf_multi_dict['all34'])])
pipe.fit(X_dict['all34'], y)
#predictions = pipe.predict(X)
In [ ]:
plt.figure(1, figsize=(8, 6))
plt.clf()
plt.plot(pca.explained_variance_, linewidth=2)
plt.axis('tight')
plt.xlabel('n components')
plt.ylabel('explained variance')
In [ ]:
pd.set_option('display.float_format', lambda x: '%.3f' % x)
display(pd.DataFrame(pca.components_,columns=X_dict['all34'].columns))
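The scree plot above shows raw eigenvalues; the cumulative explained variance ratio is often easier to read when judging whether 9 components suffice:

In [ ]:
import numpy as np

# Cumulative explained variance ratio of the fitted PCA.
cum = np.cumsum(pca.explained_variance_ratio_)
for i, frac in enumerate(cum, start=1):
    print('{} components: {:.1%}'.format(i, frac))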

Random Forest

In [49]:
from sklearn.ensemble import RandomForestClassifier

# NB: the @ignore_warnings(ConvergenceWarning) decorator carried over from the
# logistic models is unnecessary here; random forests never raise that warning.
def run(X1, y):
    # Accept either a Series (single feature) or a DataFrame.
    X = X1.to_frame() if isinstance(X1, pd.Series) else X1
    # max_depth=2 keeps the trees shallow and the forest strongly regularised.
    return RandomForestClassifier(max_depth=2, random_state=0).fit(X, y)

clf2_multi_dict = dict()
for v in X_dict:
    print(v)
    clf2_multi_dict[v] = run(X_dict[v], y_dict[v])
34_resp
num_34_resp
cat_34_resp
bin_34_resp
lit_34_resp
ola_34_resp
34_exp
num_34_exp
cat_34_exp
bin_34_exp
34_multi
num_34_multi
cat_34_multi
bin_34_multi
lit_34_multi
ola_34_multi
34_lithium2y
num_34_lithium2y
cat_34_lithium2y
bin_34_lithium2y
37_resp
num_37_resp
cat_37_resp
bin_37_resp
lit_37_resp
ola_37_resp
37_exp
num_37_exp
cat_37_exp
bin_37_exp
37_multi
num_37_multi
cat_37_multi
bin_37_multi
lit_37_multi
ola_37_multi
37_lithium2y
num_37_lithium2y
cat_37_lithium2y
bin_37_lithium2y
13_resp
num_13_resp
cat_13_resp
bin_13_resp
lit_13_resp
ola_13_resp
13_exp
num_13_exp
cat_13_exp
bin_13_exp
13_multi
num_13_multi
cat_13_multi
bin_13_multi
lit_13_multi
ola_13_multi
13_lithium2y
num_13_lithium2y
cat_13_lithium2y
bin_13_lithium2y
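max_depth=2 is a strong constraint; a quick sketch (with a hypothetical choice of feature set) of how tree depth moves the cross-validated balanced accuracy:

In [ ]:
from sklearn.model_selection import cross_val_score

v = '13_resp'  # hypothetical choice of feature set
for depth in (2, 4, 8, None):
    clf = RandomForestClassifier(max_depth=depth, random_state=0)
    score = cross_val_score(clf, X_dict[v], y_dict[v], cv=5,
                            scoring='balanced_accuracy').mean()
    print('max_depth={}: {:.3f}'.format(depth, score))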

Results

In [79]:
all_results2 = evaluate(clf2_multi_dict, x_test, y_test)
________________________________________34_resp________________________________________
Balanced Accuracy: 0.5110331975606687
ROC_AUC score: 0.6441694746055537
___________________________________________________________________________________




________________________________________num_34_resp________________________________________
Balanced Accuracy: 0.5773642464045033
ROC_AUC score: 0.6390544015514696
___________________________________________________________________________________




________________________________________cat_34_resp________________________________________
Balanced Accuracy: 0.5
ROC_AUC score: 0.5785016761821947
___________________________________________________________________________________




________________________________________bin_34_resp________________________________________
Balanced Accuracy: 0.5
ROC_AUC score: 0.5758893633940115
___________________________________________________________________________________




________________________________________lit_34_resp________________________________________
Balanced Accuracy: 0.6069482286096927
ROC_AUC score: 0.6300764451501537
___________________________________________________________________________________




________________________________________ola_34_resp________________________________________
Balanced Accuracy: 0.5
ROC_AUC score: 0.6840147979384636
___________________________________________________________________________________




________________________________________34_exp________________________________________
Balanced Accuracy: 0.508012755987941
ROC_AUC score: 0.582772370162715
___________________________________________________________________________________




________________________________________num_34_exp________________________________________
Balanced Accuracy: 0.5187031412259908
ROC_AUC score: 0.5527076800995852
___________________________________________________________________________________




________________________________________cat_34_exp________________________________________
Balanced Accuracy: 0.509346904657651
ROC_AUC score: 0.5603262939270862
___________________________________________________________________________________




________________________________________bin_34_exp________________________________________
Balanced Accuracy: 0.5144463220108947
ROC_AUC score: 0.5749915488539155
___________________________________________________________________________________




________________________________________34_multi________________________________________
Balanced Accuracy: 0.34516011062967616
ROC_AUC score: 0.611353800587708
___________________________________________________________________________________




________________________________________num_34_multi________________________________________
Balanced Accuracy: 0.3847279954832512
ROC_AUC score: 0.605524634439378
___________________________________________________________________________________




________________________________________cat_34_multi________________________________________
Balanced Accuracy: 0.3333333333333333
ROC_AUC score: 0.56156310061291
___________________________________________________________________________________




________________________________________bin_34_multi________________________________________
Balanced Accuracy: 0.3333333333333333
ROC_AUC score: 0.5625488160989097
___________________________________________________________________________________




________________________________________lit_34_multi________________________________________
Balanced Accuracy: 0.40293891192888515
ROC_AUC score: 0.6005018451992608
___________________________________________________________________________________




________________________________________ola_34_multi________________________________________
Balanced Accuracy: 0.3333333333333333
ROC_AUC score: 0.6420242532566227
___________________________________________________________________________________




________________________________________34_lithium2y________________________________________
Balanced Accuracy: 0.5
ROC_AUC score: 0.6211246390563939
___________________________________________________________________________________




________________________________________num_34_lithium2y________________________________________
Balanced Accuracy: 0.5
ROC_AUC score: 0.6013454396114563
___________________________________________________________________________________




________________________________________cat_34_lithium2y________________________________________
Balanced Accuracy: 0.5
ROC_AUC score: 0.5904663538063608
___________________________________________________________________________________




________________________________________bin_34_lithium2y________________________________________
Balanced Accuracy: 0.5
ROC_AUC score: 0.5786242468650198
___________________________________________________________________________________




________________________________________37_resp________________________________________
Balanced Accuracy: 0.5845093305568586
ROC_AUC score: 0.6580121365888261
___________________________________________________________________________________




________________________________________num_37_resp________________________________________
Balanced Accuracy: 0.6025075514010136
ROC_AUC score: 0.6566489513849956
___________________________________________________________________________________




________________________________________cat_37_resp________________________________________
Balanced Accuracy: 0.5
ROC_AUC score: 0.5794025526023729
___________________________________________________________________________________




________________________________________bin_37_resp________________________________________
Balanced Accuracy: 0.5
ROC_AUC score: 0.5777174317227491
___________________________________________________________________________________




________________________________________lit_37_resp________________________________________
Balanced Accuracy: 0.601644382694887
ROC_AUC score: 0.6358305816688019
___________________________________________________________________________________




________________________________________ola_37_resp________________________________________
Balanced Accuracy: 0.5
ROC_AUC score: 0.6925096035450168
___________________________________________________________________________________




________________________________________37_exp________________________________________
Balanced Accuracy: 0.5125252066441527
ROC_AUC score: 0.581302279670443
___________________________________________________________________________________




________________________________________num_37_exp________________________________________
Balanced Accuracy: 0.519850861387783
ROC_AUC score: 0.5650832616139299
___________________________________________________________________________________




________________________________________cat_37_exp________________________________________
Balanced Accuracy: 0.5111294044521393
ROC_AUC score: 0.5589588020259263
___________________________________________________________________________________




________________________________________bin_37_exp________________________________________
Balanced Accuracy: 0.5129434649782829
ROC_AUC score: 0.5686375662041394
___________________________________________________________________________________




________________________________________37_multi________________________________________
Balanced Accuracy: 0.38740060468928744
ROC_AUC score: 0.6205182267712338
___________________________________________________________________________________




________________________________________num_37_multi________________________________________
Balanced Accuracy: 0.4014441459086413
ROC_AUC score: 0.6173249209650806
___________________________________________________________________________________




________________________________________cat_37_multi________________________________________
Balanced Accuracy: 0.3333333333333333
ROC_AUC score: 0.5630210871861139
___________________________________________________________________________________




________________________________________bin_37_multi________________________________________
Balanced Accuracy: 0.3333333333333333
ROC_AUC score: 0.56365325092559
___________________________________________________________________________________




________________________________________lit_37_multi________________________________________
Balanced Accuracy: 0.4018588955331041
ROC_AUC score: 0.6062636060009988
___________________________________________________________________________________




________________________________________ola_37_multi________________________________________
Balanced Accuracy: 0.3333333333333333
ROC_AUC score: 0.6479986686804441
___________________________________________________________________________________




________________________________________37_lithium2y________________________________________
Balanced Accuracy: 0.5
ROC_AUC score: 0.6199223971246952
___________________________________________________________________________________




________________________________________num_37_lithium2y________________________________________
Balanced Accuracy: 0.5
ROC_AUC score: 0.60939492842348
___________________________________________________________________________________




________________________________________cat_37_lithium2y________________________________________
Balanced Accuracy: 0.5
ROC_AUC score: 0.5880130054485904
___________________________________________________________________________________




________________________________________bin_37_lithium2y________________________________________
Balanced Accuracy: 0.5
ROC_AUC score: 0.5778896851755487
___________________________________________________________________________________




________________________________________13_resp________________________________________
Balanced Accuracy: 0.5477520280088328
ROC_AUC score: 0.637369066143408
___________________________________________________________________________________




________________________________________num_13_resp________________________________________
Balanced Accuracy: 0.5788939771856157
ROC_AUC score: 0.6401284024210251
___________________________________________________________________________________




________________________________________cat_13_resp________________________________________
Balanced Accuracy: 0.5
ROC_AUC score: 0.5358554965046167
___________________________________________________________________________________




________________________________________bin_13_resp________________________________________
Balanced Accuracy: 0.5
ROC_AUC score: 0.5100019593597326
___________________________________________________________________________________




________________________________________lit_13_resp________________________________________
Balanced Accuracy: 0.5995958210786525
ROC_AUC score: 0.6204130403838095
___________________________________________________________________________________




________________________________________ola_13_resp________________________________________
Balanced Accuracy: 0.5
ROC_AUC score: 0.674893734285946
___________________________________________________________________________________




________________________________________13_exp________________________________________
Balanced Accuracy: 0.5095903731589922
ROC_AUC score: 0.552450056452817
___________________________________________________________________________________




________________________________________num_13_exp________________________________________
Balanced Accuracy: 0.5182964167140602
ROC_AUC score: 0.5508419270794014
___________________________________________________________________________________




________________________________________cat_13_exp________________________________________
Balanced Accuracy: 0.5
ROC_AUC score: 0.5219919057732861
___________________________________________________________________________________




________________________________________bin_13_exp________________________________________
Balanced Accuracy: 0.5
ROC_AUC score: 0.5737315312050705
___________________________________________________________________________________




________________________________________13_multi________________________________________
Balanced Accuracy: 0.36479965529680997
ROC_AUC score: 0.605131810802567
___________________________________________________________________________________




________________________________________num_13_multi________________________________________
Balanced Accuracy: 0.38581131764253085
ROC_AUC score: 0.6047799299462205
___________________________________________________________________________________




________________________________________cat_13_multi________________________________________
Balanced Accuracy: 0.3333333333333333
ROC_AUC score: 0.5304525184350798
___________________________________________________________________________________




________________________________________bin_13_multi________________________________________
Balanced Accuracy: 0.3333333333333333
ROC_AUC score: 0.5108637439107558
___________________________________________________________________________________




________________________________________lit_13_multi________________________________________
Balanced Accuracy: 0.3999337699531935
ROC_AUC score: 0.593113733735595
___________________________________________________________________________________




________________________________________ola_13_multi________________________________________
Balanced Accuracy: 0.3333333333333333
ROC_AUC score: 0.6325575057349074
___________________________________________________________________________________




________________________________________13_lithium2y________________________________________
Balanced Accuracy: 0.5
ROC_AUC score: 0.6051245861265011
___________________________________________________________________________________




________________________________________num_13_lithium2y________________________________________
Balanced Accuracy: 0.5
ROC_AUC score: 0.6015531734978113
___________________________________________________________________________________




________________________________________cat_13_lithium2y________________________________________
Balanced Accuracy: 0.5
ROC_AUC score: 0.5407494809044252
___________________________________________________________________________________




________________________________________bin_13_lithium2y________________________________________
Balanced Accuracy: 0.5
ROC_AUC score: 0.5400218543063251
___________________________________________________________________________________




In [82]:
display(all_results2)
features balanced accuracy accuracy roc_auc f1 (response) f1 (equivocal) f1 (no response) f1 (lithium) f1 (olanzapine) f1 (lithium > 2y) f1 (other) f1 (weighted avg)
0 34_resp 0.511033 0.545266 0.644169 0.059306 NaN 0.700162 NaN NaN NaN NaN 0.402578
1 num_34_resp 0.577364 0.594115 0.639054 0.439291 NaN 0.681939 NaN NaN NaN NaN 0.569265
2 cat_34_resp 0.500000 0.535647 0.578502 0.000000 NaN 0.697617 NaN NaN NaN NaN 0.373677
3 bin_34_resp 0.500000 0.535647 0.575889 0.000000 NaN 0.697617 NaN NaN NaN NaN 0.373677
4 lit_34_resp 0.606948 0.610132 0.630076 0.644272 NaN 0.568743 NaN NaN NaN NaN 0.610691
5 ola_34_resp 0.500000 0.598588 0.684015 0.000000 NaN 0.748896 NaN NaN NaN NaN 0.448280
6 34_exp 0.508013 0.424450 0.582772 NaN NaN NaN 0.584895 0.061834 NaN NaN 0.277428
7 num_34_exp 0.518703 0.436882 0.552708 NaN NaN NaN 0.590377 0.099414 NaN NaN 0.301778
8 cat_34_exp 0.509347 0.429551 0.560326 NaN NaN NaN 0.582039 0.101882 NaN NaN 0.299792
9 bin_34_exp 0.514446 0.440708 0.574992 NaN NaN NaN 0.579307 0.165914 NaN NaN 0.336305
10 34_multi 0.345160 0.469238 0.611354 0.093006 0.0 0.630719 NaN NaN NaN NaN 0.324555
11 num_34_multi 0.384728 0.505738 0.605525 0.421562 0.0 0.614419 NaN NaN NaN NaN 0.447623
12 cat_34_multi 0.333333 0.456009 0.561563 0.000000 0.0 0.626382 NaN NaN NaN NaN 0.285636
13 bin_34_multi 0.333333 0.456009 0.562549 0.000000 0.0 0.626382 NaN NaN NaN NaN 0.285636
14 lit_34_multi 0.402939 0.512112 0.600502 0.597719 0.0 0.505909 NaN NaN NaN NaN 0.467889
15 ola_34_multi 0.333333 0.510501 0.642024 0.000000 0.0 0.675936 NaN NaN NaN NaN 0.345066
16 34_lithium2y 0.500000 0.801084 0.621125 NaN NaN NaN NaN NaN 0.0 0.889558 0.712610
17 num_34_lithium2y 0.500000 0.801084 0.601345 NaN NaN NaN NaN NaN 0.0 0.889558 0.712610
18 cat_34_lithium2y 0.500000 0.801084 0.590466 NaN NaN NaN NaN NaN 0.0 0.889558 0.712610
19 bin_34_lithium2y 0.500000 0.801084 0.578624 NaN NaN NaN NaN NaN 0.0 0.889558 0.712610
20 37_resp 0.584509 0.605998 0.658012 0.400230 NaN 0.706642 NaN NaN NaN NaN 0.564359
21 num_37_resp 0.602508 0.618069 0.656649 0.483023 NaN 0.697174 NaN NaN NaN NaN 0.597732
22 cat_37_resp 0.500000 0.535647 0.579403 0.000000 NaN 0.697617 NaN NaN NaN NaN 0.373677
23 bin_37_resp 0.500000 0.535647 0.577717 0.000000 NaN 0.697617 NaN NaN NaN NaN 0.373677
24 lit_37_resp 0.601644 0.606671 0.635831 0.646293 NaN 0.557052 NaN NaN NaN NaN 0.606615
25 ola_37_resp 0.500000 0.598588 0.692510 0.000000 NaN 0.748896 NaN NaN NaN NaN 0.448280
26 37_exp 0.512525 0.433695 0.581302 NaN NaN NaN 0.583226 0.116828 NaN NaN 0.309067
27 num_37_exp 0.519851 0.443258 0.565083 NaN NaN NaN 0.585990 0.150328 NaN NaN 0.329898
28 cat_37_exp 0.511129 0.431782 0.558959 NaN NaN NaN 0.582797 0.109418 NaN NaN 0.304534
29 bin_37_exp 0.512943 0.436564 0.568638 NaN NaN NaN 0.581012 0.140112 NaN NaN 0.321841
30 37_multi 0.387401 0.513548 0.620518 0.379188 0.0 0.634992 NaN NaN NaN NaN 0.440174
31 num_37_multi 0.401444 0.526458 0.617325 0.458353 0.0 0.630303 NaN NaN NaN NaN 0.469479
32 cat_37_multi 0.333333 0.456009 0.563021 0.000000 0.0 0.626382 NaN NaN NaN NaN 0.285636
33 bin_37_multi 0.333333 0.456009 0.563653 0.000000 0.0 0.626451 NaN NaN NaN NaN 0.285667
34 lit_37_multi 0.401859 0.511058 0.606264 0.596645 0.0 0.504175 NaN NaN NaN NaN 0.466740
35 ola_37_multi 0.333333 0.510501 0.647999 0.000000 0.0 0.675936 NaN NaN NaN NaN 0.345066
36 37_lithium2y 0.500000 0.801084 0.619922 NaN NaN NaN NaN NaN 0.0 0.889558 0.712610
37 num_37_lithium2y 0.500000 0.801084 0.609395 NaN NaN NaN NaN NaN 0.0 0.889558 0.712610
38 cat_37_lithium2y 0.500000 0.801084 0.588013 NaN NaN NaN NaN NaN 0.0 0.889558 0.712610
39 bin_37_lithium2y 0.500000 0.801084 0.577890 NaN NaN NaN NaN NaN 0.0 0.889558 0.712610
40 13_resp 0.547752 0.572614 0.637369 0.301910 NaN 0.692036 NaN NaN NaN NaN 0.510880
41 num_13_resp 0.578894 0.595436 0.640128 0.443291 NaN 0.682269 NaN NaN NaN NaN 0.571299
42 cat_13_resp 0.500000 0.535647 0.535855 0.000000 NaN 0.697617 NaN NaN NaN NaN 0.373677
43 bin_13_resp 0.500000 0.535647 0.510002 0.000000 NaN 0.697617 NaN NaN NaN NaN 0.373677
44 lit_13_resp 0.599596 0.606042 0.620413 0.649692 NaN 0.549964 NaN NaN NaN NaN 0.605351
45 ola_13_resp 0.500000 0.598588 0.674894 0.000000 NaN 0.748896 NaN NaN NaN NaN 0.448280
46 13_exp 0.509590 0.424131 0.552450 NaN NaN NaN 0.587792 0.044938 NaN NaN 0.268690
47 num_13_exp 0.518296 0.436404 0.550842 NaN NaN NaN 0.590172 0.097959 NaN NaN 0.300838
48 cat_13_exp 0.500000 0.412177 0.521992 NaN NaN NaN 0.583747 0.000000 NaN NaN 0.240607
49 bin_13_exp 0.500000 0.412177 0.573732 NaN NaN NaN 0.583747 0.000000 NaN NaN 0.240607
50 13_multi 0.364800 0.487727 0.605132 0.285375 0.0 0.624471 NaN NaN NaN NaN 0.398113
51 num_13_multi 0.385811 0.507810 0.604780 0.417036 0.0 0.618374 NaN NaN NaN NaN 0.447629
52 cat_13_multi 0.333333 0.456009 0.530453 0.000000 0.0 0.626382 NaN NaN NaN NaN 0.285636
53 bin_13_multi 0.333333 0.456009 0.510864 0.000000 0.0 0.626382 NaN NaN NaN NaN 0.285636
54 lit_13_multi 0.399934 0.509215 0.593114 0.596050 0.0 0.499676 NaN NaN NaN NaN 0.464782
55 ola_13_multi 0.333333 0.510501 0.632558 0.000000 0.0 0.675936 NaN NaN NaN NaN 0.345066
56 13_lithium2y 0.500000 0.801084 0.605125 NaN NaN NaN NaN NaN 0.0 0.889558 0.712610
57 num_13_lithium2y 0.500000 0.801084 0.601553 NaN NaN NaN NaN NaN 0.0 0.889558 0.712610
58 cat_13_lithium2y 0.500000 0.801084 0.540749 NaN NaN NaN NaN NaN 0.0 0.889558 0.712610
59 bin_13_lithium2y 0.500000 0.801084 0.540022 NaN NaN NaN NaN NaN 0.0 0.889558 0.712610

Feature Importance

In [60]:
# Impurity-based importances from each random forest, one bar chart per feature set.
for v in clf2_multi_dict.keys():
    importances = clf2_multi_dict[v].feature_importances_
    indices = np.argsort(importances)
    features = X_dict[v].columns
    plt.figure(1, figsize=(8, 10))
    plt.clf()
    plt.title('Feature Importances ' + v)
    plt.barh(range(len(indices)), importances[indices], color='b', align='center')
    plt.yticks(range(len(indices)), [features[i] for i in indices])
    plt.xlabel('Relative Importance')
    plt.show()
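Impurity-based importances can overstate high-cardinality features; a sketch of permutation importance on the held-out split as a cross-check (the feature set chosen here is hypothetical):

In [ ]:
from sklearn.inspection import permutation_importance

v = '13_resp'  # hypothetical choice of feature set
r = permutation_importance(clf2_multi_dict[v], x_test[v], y_test[v],
                           n_repeats=10, random_state=0)
for i in r.importances_mean.argsort()[::-1]:
    print('{}: {:.4f}'.format(X_dict[v].columns[i], r.importances_mean[i]))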

SHAP

In [61]:
import shap
# TreeExplainer on each forest; for classifiers shap_values returns one array
# per class, which summary_plot renders as grouped bars.
for v in clf2_multi_dict.keys():
    print(40*'_' + v + 40*'_')
    shap_values = shap.TreeExplainer(clf2_multi_dict[v]).shap_values(X_dict[v])
    shap.summary_plot(shap_values, X_dict[v], plot_type="bar")
    print(83*'_')
    print('\n\n\n')
[Output: one SHAP summary bar plot per feature set, from 34_resp through bin_13_lithium2y, each between a pair of separator lines; the plot images are not reproduced here.]

Naive Bayes

In [52]:
from sklearn.naive_bayes import GaussianNB, BernoulliNB
from mixed_naive_bayes import MixedNB
clf3_dict = dict()

# Loop over the combined feature sets only, i.e. keys without a
# num_/cat_/bin_/lit_/ola_ prefix (prefixed keys contain '_13', '_34' or '_37').
for v in [feats for feats in X_dict if ('_13' not in feats) and ('_3' not in feats)]:
    print(v)
    # Gaussian NB on the numeric columns...
    gnb = GaussianNB()
    v_num = 'num_' + v
    clf3_dict[v_num] = gnb.fit(X_dict[v_num], y_dict[v_num])

    # ...Bernoulli NB on the binary columns...
    bnb = BernoulliNB()
    v_bin = 'bin_' + v
    clf3_dict[v_bin] = bnb.fit(X_dict[v_bin], y_dict[v_bin])

    # ...and MixedNB on the full frame, with the positional indices of the
    # binary columns passed as its categorical features.
    bin_features = [X_dict[v].columns.get_loc(c) for c in list(X_dict['bin_' + v].columns)]
    mnb = MixedNB(categorical_features=bin_features)
    clf3_dict[v] = mnb.fit(np.array(X_dict[v]), np.array(y_dict[v]))
34_resp
[2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2]
34_exp
[2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2]
34_multi
[2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2]
34_lithium2y
[2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2]
37_resp
[2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2]
37_exp
[2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2]
37_multi
[2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2]
37_lithium2y
[2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2]
13_resp
[2 2 2 2 2 2 2 2]
13_exp
[2 2 2 2 2 2 2 2]
13_multi
[2 2 2 2 2 2 2 2]
13_lithium2y
[2 2 2 2 2 2 2 2]
In [93]:
all_results3 = evaluate(clf3_dict, x_test, y_test)
________________________________________num_34_resp________________________________________
Balanced Accuracy: 0.5857668962597682
ROC_AUC score: 0.6380105061726984
___________________________________________________________________________________




________________________________________bin_34_resp________________________________________
Balanced Accuracy: 0.5367480063157173
ROC_AUC score: 0.5760540497248314
___________________________________________________________________________________




________________________________________34_resp________________________________________
Balanced Accuracy: 0.5952593505795128
ROC_AUC score: 0.50115180318303
___________________________________________________________________________________




________________________________________num_34_exp________________________________________
Balanced Accuracy: 0.5224358006485783
ROC_AUC score: 0.540462221070238
___________________________________________________________________________________




________________________________________bin_34_exp________________________________________
Balanced Accuracy: 0.5493286896068099
ROC_AUC score: 0.5724976219355683
___________________________________________________________________________________




________________________________________34_exp________________________________________
Balanced Accuracy: 0.5580377738968214
ROC_AUC score: 0.49567450211635156
___________________________________________________________________________________




________________________________________num_34_multi________________________________________
Balanced Accuracy: 0.39002599299933854
ROC_AUC score: 0.6033948361067034
___________________________________________________________________________________




________________________________________bin_34_multi________________________________________
Balanced Accuracy: 0.3575640114310101
ROC_AUC score: 0.5615403377461362
___________________________________________________________________________________




________________________________________34_multi________________________________________
Balanced Accuracy: 0.3965460299008324
ROC_AUC score: None
___________________________________________________________________________________




________________________________________num_34_lithium2y________________________________________
Balanced Accuracy: 0.5325429561153795
ROC_AUC score: 0.5977664541818442
___________________________________________________________________________________




________________________________________bin_34_lithium2y________________________________________
Balanced Accuracy: 0.5004006410256411
ROC_AUC score: 0.584060401935576
___________________________________________________________________________________




________________________________________34_lithium2y________________________________________
Balanced Accuracy: 0.5462642336771864
ROC_AUC score: 0.5140489008438173
___________________________________________________________________________________




________________________________________num_37_resp________________________________________
Balanced Accuracy: 0.5951753984508186
ROC_AUC score: 0.6473412775597532
___________________________________________________________________________________




________________________________________bin_37_resp________________________________________
Balanced Accuracy: 0.5367480063157173
ROC_AUC score: 0.5760540497248314
___________________________________________________________________________________




________________________________________37_resp________________________________________
Balanced Accuracy: 0.6166868800128145
ROC_AUC score: 0.5057420681685564
___________________________________________________________________________________




________________________________________num_37_exp________________________________________
Balanced Accuracy: 0.5224040302110647
ROC_AUC score: 0.5437385605454366
___________________________________________________________________________________




________________________________________bin_37_exp________________________________________
Balanced Accuracy: 0.5493286896068099
ROC_AUC score: 0.5724976219355683
___________________________________________________________________________________




________________________________________37_exp________________________________________
Balanced Accuracy: 0.5421201555849703
ROC_AUC score: 0.5031892066911268
___________________________________________________________________________________




________________________________________num_37_multi________________________________________
Balanced Accuracy: 0.39620903166493027
ROC_AUC score: 0.610487813737447
___________________________________________________________________________________




________________________________________bin_37_multi________________________________________
Balanced Accuracy: 0.3575640114310101
ROC_AUC score: 0.5615403377461362
___________________________________________________________________________________




________________________________________37_multi________________________________________
Balanced Accuracy: 0.4107821949731159
ROC_AUC score: None
___________________________________________________________________________________




________________________________________num_37_lithium2y________________________________________
Balanced Accuracy: 0.5497151989143633
ROC_AUC score: 0.6024776929198934
___________________________________________________________________________________




________________________________________bin_37_lithium2y________________________________________
Balanced Accuracy: 0.5004006410256411
ROC_AUC score: 0.584060401935576
___________________________________________________________________________________




________________________________________37_lithium2y________________________________________
Balanced Accuracy: 0.5824579175467058
ROC_AUC score: 0.5084447093064781
___________________________________________________________________________________




________________________________________num_13_resp________________________________________
Balanced Accuracy: 0.5857668962597682
ROC_AUC score: 0.6379753235088843
___________________________________________________________________________________




________________________________________bin_13_resp________________________________________
Balanced Accuracy: 0.5
ROC_AUC score: 0.5130504084621457
___________________________________________________________________________________




________________________________________13_resp________________________________________
Balanced Accuracy: 0.5832612041052162
ROC_AUC score: 0.4957072144483473
___________________________________________________________________________________




________________________________________num_13_exp________________________________________
Balanced Accuracy: 0.5224358006485783
ROC_AUC score: 0.5404592851882236
___________________________________________________________________________________




________________________________________bin_13_exp________________________________________
Balanced Accuracy: 0.5025319885316061
ROC_AUC score: 0.5606965296196942
___________________________________________________________________________________




________________________________________13_exp________________________________________
Balanced Accuracy: 0.539042931822109
ROC_AUC score: 0.47635325287338964
___________________________________________________________________________________




________________________________________num_13_multi________________________________________
Balanced Accuracy: 0.39002599299933854
ROC_AUC score: 0.6033941180690108
___________________________________________________________________________________




________________________________________bin_13_multi________________________________________
Balanced Accuracy: 0.3333333333333333
ROC_AUC score: 0.5108453613008589
___________________________________________________________________________________




________________________________________13_multi________________________________________
Balanced Accuracy: 0.38824810288881545
ROC_AUC score: None
___________________________________________________________________________________




________________________________________num_13_lithium2y________________________________________
Balanced Accuracy: 0.5325429561153795
ROC_AUC score: 0.5977613525054333
___________________________________________________________________________________




________________________________________bin_13_lithium2y________________________________________
Balanced Accuracy: 0.5
ROC_AUC score: 0.5417565837134082
___________________________________________________________________________________




________________________________________13_lithium2y________________________________________
Balanced Accuracy: 0.5359470497005316
ROC_AUC score: 0.5166446975726223
___________________________________________________________________________________




roc_auc_score raises a ValueError for 34_multi: predict_proba() on MixedNB returns scores that do not sum to 1 across classes, which the multiclass ROC AUC validation rejects.

In [91]:
v = '34_multi'
v2 = v
y_score = clf3_dict[v].predict_proba(x_test[v2])
print(y_score[0])
roc_auc_score(y_test[v2], y_score, average='weighted', multi_class='ovr')
#result = results(clf3_dict[v], x_test[v2], y_test[v2], v)
[4.16702351e-06 1.29765617e-06 3.37361391e-06]
---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
Input In [91], in <cell line: 5>()
      3 y_score = clf3_dict[v].predict_proba(x_test[v2])
      4 print(y_score[0])
----> 5 roc_auc_score(y_test[v2], y_score, average='weighted', multi_class='ovr')

File ~/GoogleDrive/sics/projects/ucl/lithium/venv/lib/python3.9/site-packages/sklearn/metrics/_ranking.py:561, in roc_auc_score(y_true, y_score, average, sample_weight, max_fpr, multi_class, labels)
    559     if multi_class == "raise":
    560         raise ValueError("multi_class must be in ('ovo', 'ovr')")
--> 561     return _multiclass_roc_auc_score(
    562         y_true, y_score, labels, multi_class, average, sample_weight
    563     )
    564 elif y_type == "binary":
    565     labels = np.unique(y_true)

File ~/GoogleDrive/sics/projects/ucl/lithium/venv/lib/python3.9/site-packages/sklearn/metrics/_ranking.py:628, in _multiclass_roc_auc_score(y_true, y_score, labels, multi_class, average, sample_weight)
    626 # validation of the input y_score
    627 if not np.allclose(1, y_score.sum(axis=1)):
--> 628     raise ValueError(
    629         "Target scores need to be probabilities for multiclass "
    630         "roc_auc, i.e. they should sum up to 1.0 over classes"
    631     )
    633 # validation for multiclass parameter specifications
    634 average_options = ("macro", "weighted")

ValueError: Target scores need to be probabilities for multiclass roc_auc, i.e. they should sum up to 1.0 over classes
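A possible workaround (a sketch; whether the renormalised scores are trustworthy depends on why MixedNB's probabilities are off in the first place):

In [ ]:
from sklearn.metrics import roc_auc_score

# Renormalise each row so the class scores sum to 1, then retry the AUC.
y_score = clf3_dict[v].predict_proba(x_test[v])
y_score = y_score / y_score.sum(axis=1, keepdims=True)
roc_auc_score(y_test[v], y_score, average='weighted', multi_class='ovr')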
In [94]:
display(all_results3)
features balanced accuracy accuracy roc_auc f1 (response) f1 (equivocal) f1 (no response) f1 (lithium) f1 (olanzapine) f1 (lithium > 2y) f1 (other) f1 (weighted avg)
0 num_34_resp 0.585767 0.599208 0.638011 0.479294 NaN 0.674230 NaN NaN NaN NaN 0.583711
1 bin_34_resp 0.536748 0.564127 0.576054 0.245511 NaN 0.693542 NaN NaN NaN NaN 0.485497
2 34_resp 0.595259 0.608827 0.501152 0.490167 NaN 0.682681 NaN NaN NaN NaN 0.593286
3 num_34_exp 0.522436 0.472107 0.540462 NaN NaN NaN 0.558164 0.344418 NaN NaN 0.432519
4 bin_34_exp 0.549329 0.529806 0.572498 NaN NaN NaN 0.536601 0.522808 NaN NaN 0.528493
5 34_exp 0.558038 0.546701 0.495675 NaN NaN NaN 0.531003 0.561382 NaN NaN 0.548860
6 num_34_multi 0.390026 0.510201 0.603395 0.451128 0.0 0.610942 NaN NaN NaN NaN 0.457781
7 bin_34_multi 0.357564 0.480236 0.56154 0.236620 0.0 0.625000 NaN NaN NaN NaN 0.378990
8 34_multi 0.396546 0.518648 None 0.462348 0.0 0.618352 NaN NaN NaN NaN 0.465616
9 num_34_lithium2y 0.532543 0.762512 0.597766 NaN NaN NaN NaN NaN 0.201501 0.860513 0.729425
10 bin_34_lithium2y 0.500401 0.801243 0.58406 NaN NaN NaN NaN NaN 0.001601 0.889636 0.712992
11 34_lithium2y 0.546264 0.768091 0.514049 NaN NaN NaN NaN NaN 0.233807 0.863367 0.738138
12 num_37_resp 0.595175 0.607318 0.647341 0.501198 NaN 0.676205 NaN NaN NaN NaN 0.594940
13 bin_37_resp 0.536748 0.564127 0.576054 0.245511 NaN 0.693542 NaN NaN NaN NaN 0.485497
14 37_resp 0.616687 0.619012 0.505742 0.587418 NaN 0.646111 NaN NaN NaN NaN 0.618857
15 num_37_exp 0.522404 0.493465 0.543739 NaN NaN NaN 0.527926 0.453576 NaN NaN 0.484222
16 bin_37_exp 0.549329 0.529806 0.572498 NaN NaN NaN 0.536601 0.522808 NaN NaN 0.528493
17 37_exp 0.542120 0.516576 0.503189 NaN NaN NaN 0.539687 0.491022 NaN NaN 0.511081
18 num_37_multi 0.396209 0.517055 0.610488 0.470327 0.0 0.613356 NaN NaN NaN NaN 0.466507
19 bin_37_multi 0.357564 0.480236 0.56154 0.236620 0.0 0.625000 NaN NaN NaN NaN 0.378990
20 37_multi 0.410782 0.527574 None 0.545080 0.0 0.590033 NaN NaN NaN NaN 0.485564
21 num_37_lithium2y 0.549715 0.743704 0.602478 NaN NaN NaN NaN NaN 0.261029 0.844967 0.728813
22 bin_37_lithium2y 0.500401 0.801243 0.58406 NaN NaN NaN NaN NaN 0.001601 0.889636 0.712992
23 37_lithium2y 0.582458 0.687600 0.508445 NaN NaN NaN NaN NaN 0.341840 0.795193 0.705014
24 num_13_resp 0.585767 0.599208 0.637975 0.479294 NaN 0.674230 NaN NaN NaN NaN 0.583711
25 bin_13_resp 0.500000 0.535647 0.51305 0.000000 NaN 0.697617 NaN NaN NaN NaN 0.373677
26 13_resp 0.583261 0.596379 0.495707 0.478811 NaN 0.670668 NaN NaN NaN NaN 0.581579
27 num_13_exp 0.522436 0.472107 0.540459 NaN NaN NaN 0.558164 0.344418 NaN NaN 0.432519
28 bin_13_exp 0.502532 0.418075 0.560697 NaN NaN NaN 0.582122 0.041984 NaN NaN 0.264616
29 13_exp 0.539043 0.493465 0.476353 NaN NaN NaN 0.565134 0.393511 NaN NaN 0.464250
30 num_13_multi 0.390026 0.510201 0.603394 0.451128 0.0 0.610942 NaN NaN NaN NaN 0.457781
31 bin_13_multi 0.333333 0.456009 0.510845 0.000000 0.0 0.626382 NaN NaN NaN NaN 0.285636
32 13_multi 0.388248 0.507651 None 0.450124 0.0 0.607827 NaN NaN NaN NaN 0.455962
33 num_13_lithium2y 0.532543 0.762512 0.597761 NaN NaN NaN NaN NaN 0.201501 0.860513 0.729425
34 bin_13_lithium2y 0.500000 0.801084 0.541757 NaN NaN NaN NaN NaN 0.000000 0.889558 0.712610
35 13_lithium2y 0.535947 0.764106 0.516645 NaN NaN NaN NaN NaN 0.209402 0.861371 0.731684

Hyperparameter Tuning

LightGBM

In [60]:
import lightgbm as lgb
from sklearn.model_selection import cross_val_score
from hyperopt import fmin, tpe, hp, STATUS_OK, Trials, space_eval
from hyperopt.pyll import scope
import time

# Search space: scope.int casts hp.quniform's float samples to the integers
# LightGBM expects; loguniform keeps the learning-rate search multiplicative.
param_hyperopt = {
    'learning_rate': hp.loguniform('learning_rate', np.log(0.01), np.log(1)),
    'max_depth': scope.int(hp.quniform('max_depth', 5, 15, 1)),
    'n_estimators': scope.int(hp.quniform('n_estimators', 5, 35, 1)),
    'num_leaves': scope.int(hp.quniform('num_leaves', 5, 50, 1)),
    'boosting_type': hp.choice('boosting_type', ['gbdt', 'dart']),
    'colsample_bytree': hp.uniform('colsample_bytree', 0.6, 1.0),
    'reg_lambda': hp.uniform('reg_lambda', 0.0, 1.0),
}

def hyperopt_lgbm(param_space, X_train, y_train, X_test, y_test, num_eval, metric):

    start = time.time()

    # hyperopt minimises the loss, so the (to-be-maximised) CV score is negated
    def objective_function(params):
        clf = lgb.LGBMClassifier(**params)
        score = cross_val_score(clf, X_train, y_train, cv=5, scoring=metric).mean()
        return {'loss': -score, 'status': STATUS_OK}

    trials = Trials()
    best_param = fmin(objective_function,
                      param_space,
                      algo=tpe.suggest,
                      max_evals=num_eval,
                      trials=trials)

    # map fmin's raw encodings back through the space to usable parameter values
    loss = [x['result']['loss'] for x in trials.trials]
    best_param_dict = space_eval(param_space, best_param)

    # refit on the full training set with the best parameters found
    clf_best = lgb.LGBMClassifier(**best_param_dict)
    clf_best.fit(X_train, y_train)

    print("")
    print("##### Results")
    print("Score best parameters: ", min(loss)*-1)
    print("Best parameters: ", best_param_dict)
    print("Test Score: ", clf_best.score(X_test, y_test))  # plain accuracy
    print("Time elapsed: ", time.time() - start)
    print("Parameter combinations evaluated: ", num_eval)

    return clf_best, best_param_dict
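The Trials object retains the full evaluation history, which is useful for checking whether 1000 evaluations were enough for the search to plateau. A sketch, hypothetical in that trials stays local to hyperopt_lgbm as written and would need to be returned alongside clf_best:

# Sketch: scatter the loss of every evaluation to eyeball convergence
# (`trials` is assumed to have been returned by hyperopt_lgbm).
losses = [t['result']['loss'] for t in trials.trials]
plt.plot(range(len(losses)), losses, '.', alpha=0.4)
plt.xlabel('evaluation')
plt.ylabel('loss (negated CV balanced accuracy)')
plt.show()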
In [64]:
results_hyperopt_lgbm = dict()
best_params_dict_lgbm = dict()
for target in ['lithium2y']:
    for feature_number in ['13']:
        feature_set = feature_number + '_' + target
        print('\n' + 10*'_' + feature_set + 10*'_')
        for metric in ['balanced_accuracy']:
            print('\n\t----> metric:', metric)
            key = feature_set + '_' + metric
            results_hyperopt_lgbm[key], best_params_dict_lgbm[key] = hyperopt_lgbm(
                param_hyperopt,
                x_train[feature_set], y_train[feature_set],
                x_test[feature_set], y_test[feature_set],
                1000, metric)
__________13_lithium2y__________

	----> metric: balanced_accuracy
100%|██████████| 1000/1000 [22:38<00:00,  1.36s/trial, best loss: -0.5348214119885094]

##### Results
Score best parameters:  0.5348214119885094
Best parameters:  {'boosting_type': 'gbdt', 'colsample_bytree': 0.600831897129334, 'learning_rate': 0.9933016655780533, 'max_depth': 12, 'n_estimators': 34, 'num_leaves': 50, 'reg_lambda': 0.6221105546939463}
Test Score:  0.7496015301243226
Time elapsed:  1359.1855659484863
Parameter combinations evaluated:  1000
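Note that the Test Score above (0.7496) comes from clf_best.score, i.e. plain accuracy, not the balanced accuracy the search optimised. A sketch for scoring the held-out set on the tuned metric instead:

# Sketch: evaluate the refit model on the metric the search actually optimised.
from sklearn.metrics import balanced_accuracy_score

clf_best = results_hyperopt_lgbm['13_lithium2y_balanced_accuracy']
preds = clf_best.predict(x_test['13_lithium2y'])
print(balanced_accuracy_score(y_test['13_lithium2y'], preds))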
In [95]:
all_results_hyper_lgbm = evaluate(results_hyperopt_lgbm, x_test, y_test)
________________________________________13_lithium2y_balanced_accuracy________________________________________
Balanced Accuracy: 0.5271952832450744
ROC_AUC score: 0.5663205179221893
___________________________________________________________________________________




In [96]:
display(all_results_hyper_lgbm)
features balanced accuracy accuracy roc_auc f1 (response) f1 (equivocal) f1 (no response) f1 (lithium) f1 (olanzapine) f1 (lithium > 2y) f1 (other) f1 (weighted avg)
0 13_lithium2y_balanced_accuracy 0.527195 0.749602 0.566321 NaN NaN NaN NaN NaN 0.200509 0.851554 0.722051

Random Forest

In [97]:
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import cross_val_score
from hyperopt import fmin, tpe, hp, STATUS_OK, Trials, space_eval
from hyperopt.pyll import scope
import time

param_hyperopt = {
    'max_depth': scope.int(hp.quniform('max_depth', 5, 15, 1)),
    'n_estimators': scope.int(hp.quniform('n_estimators', 5, 35, 1)),
    'min_samples_split': hp.uniform('min_samples_split', 0, 1),
    'min_samples_leaf': hp.randint('min_samples_leaf', 1, 10),
    'criterion': hp.choice('criterion', ['gini', 'entropy']),
    'max_features': hp.choice('max_features', ['sqrt', 'log2'])
}


def hyperopt_rf(param_space, X_train, y_train, X_test, y_test, num_eval, metric):
    start = time.time()

    def objective_function(params):
        # cross_val_score clones and fits the estimator itself, so no
        # separate clf.fit is needed here
        clf = RandomForestClassifier(**params, random_state=42)
        score = cross_val_score(clf, X_train, y_train, cv=5, scoring=metric).mean()
        return {'loss': -score, 'status': STATUS_OK}

    trials = Trials()
    best_param = fmin(objective_function,
                      param_space,
                      algo=tpe.suggest,
                      max_evals=num_eval,
                      trials=trials)

    best_param_dict = space_eval(param_space, best_param)
    loss = [x['result']['loss'] for x in trials.trials]

    # refit on the full training set, seeded for consistency with the objective
    clf_best = RandomForestClassifier(**best_param_dict, random_state=42)
    clf_best.fit(X_train, y_train)

    print("##### Results")
    print("Score best parameters: ", min(loss)*-1)
    # NB: best_param holds hyperopt's raw encodings (choice indices, quniform
    # floats); best_param_dict holds the decoded, human-readable values
    print("Best parameters: ", best_param)
    print("Test Score: ", clf_best.score(X_test, y_test))  # plain accuracy
    print("Time elapsed: ", time.time() - start)
    print("Parameter combinations evaluated: ", num_eval)

    return clf_best, best_param_dict
In [98]:
results_hyperopt = dict()
best_params_dict = dict()
for target in ['lithium2y']:
    for feature_number in ['13']:
        feature_set = feature_number + '_' + target
        print('\n' + 10*'_' + feature_set + 10*'_')
        for metric in ['balanced_accuracy']:
            print('\n\t----> metric:', metric)
            key = feature_set + '_' + metric
            results_hyperopt[key], best_params_dict[key] = hyperopt_rf(
                param_hyperopt,
                x_train[feature_set], y_train[feature_set],
                x_test[feature_set], y_test[feature_set],
                1000, metric)
__________13_lithium2y__________

	----> metric: balanced_accuracy
100%|██████████| 1000/1000 [19:17<00:00,  1.16s/trial, best loss: -0.5106134589427425]
##### Results
Score best parameters:  0.5106134589427425
Best parameters:  {'criterion': 0, 'max_depth': 14.0, 'max_features': 0, 'min_samples_leaf': 2, 'min_samples_split': 0.00035845671793269197, 'n_estimators': 16.0}
Test Score:  0.7975773031558814
Time elapsed:  1158.2907378673553
Parameter combinations evaluated:  1000
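The Best parameters line above prints fmin's raw encodings rather than usable values: hp.choice entries come back as indices ('criterion': 0 means 'gini', 'max_features': 0 means 'sqrt') and quniform samples as floats ('max_depth': 14.0). A minimal sketch of decoding them with space_eval, which is what best_param_dict already holds:

# Sketch: space_eval maps fmin's raw encodings back through the search space.
from hyperopt import hp, space_eval
from hyperopt.pyll import scope

space = {
    'criterion': hp.choice('criterion', ['gini', 'entropy']),
    'max_depth': scope.int(hp.quniform('max_depth', 5, 15, 1)),
}
print(space_eval(space, {'criterion': 0, 'max_depth': 14.0}))
# -> {'criterion': 'gini', 'max_depth': 14}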
In [99]:
all_results_hyper_rf = evaluate(results_hyperopt, x_test, y_test)
________________________________________13_lithium2y_balanced_accuracy________________________________________
Balanced Accuracy: 0.5029310725254319
ROC_AUC score: 0.589657897522626
___________________________________________________________________________________




In [100]:
display(all_results_hyper_rf)
features balanced accuracy accuracy roc_auc f1 (response) f1 (equivocal) f1 (no response) f1 (lithium) f1 (olanzapine) f1 (lithium > 2y) f1 (other) f1 (weighted avg)
0 13_lithium2y_balanced_accuracy 0.502931 0.797577 0.589658 NaN NaN NaN NaN NaN 0.026074 0.887051 0.715789
In [71]:
from sklearn.model_selection import GridSearchCV
from sklearn.pipeline import Pipeline
from sklearn.svm import LinearSVC

from umap import UMAP

# Classification with a linear SVM
svc = LinearSVC(dual=False, random_state=123)
params_grid = {"C": [10 ** k for k in range(-3, 4)]}
clf = GridSearchCV(svc, params_grid, scoring='balanced_accuracy')
feature_set = '13_lithium2y'
clf.fit(x_train[feature_set], y_train[feature_set])
print(
    "Balanced accuracy on the test set with raw data: {:.3f}".format(clf.score(x_test[feature_set], y_test[feature_set]))
)

# Transformation with UMAP followed by classification with a linear SVM
umap = UMAP(random_state=456)
pipeline = Pipeline([("umap", umap), ("svc", svc)])
params_grid_pipeline = {
    "umap__n_neighbors": [5, 20],
    "umap__n_components": [15, 25, 50],
    "svc__C": [10 ** k for k in range(-3, 4)],
}


clf_pipeline = GridSearchCV(pipeline, params_grid_pipeline, scoring='balanced_accuracy')
clf_pipeline.fit(x_train[feature_set], y_train[feature_set])
print(
    "Balanced accuracy on the test set with UMAP transformation: {:.3f}".format(
        clf_pipeline.score(x_test[feature_set], y_test[feature_set])
    )
)
Balanced accuracy on the test set with raw data: 0.504
Balanced accuracy on the test set with UMAP transformation: 0.513
In [74]:
print("Balanced accuracy on the test set with raw data: {:.6f}".format(clf.score(x_test[feature_set], y_test[feature_set])),
      "\nBalanced accuracy on the test set with UMAP transformation: {:.6f}".format(clf_pipeline.score(x_test[feature_set], y_test[feature_set])))
Balanced accuracy on the test set with raw data: 0.504214 
Balanced accuracy on the test set with UMAP transformation: 0.513193
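GridSearchCV reaches individual pipeline steps through the step-name prefix: the step registered as "umap" exposes its n_neighbors as umap__n_neighbors, and likewise svc__C. (Since sklearn's Pipeline forwards y to intermediate steps during fit, UMAP runs here as a supervised embedding.) A quick sketch of the same double-underscore convention via set_params:

# Sketch: the step__parameter routing used in params_grid_pipeline above.
pipeline.set_params(umap__n_neighbors=20, svc__C=0.1)
print([k for k in pipeline.get_params() if k.startswith('umap__')][:3])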